From 2b156fffb8f5bb1924fcf2a7aa0713ddebf2ea4f Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 22 Apr 2024 18:06:22 -0500 Subject: [PATCH 1/8] Cleaned up commented code and switched off sanitizer by default in CMakeLists --- CMakeLists.txt | 2 +- include/stochtree/prior.h | 176 -------------------------------------- 2 files changed, 1 insertion(+), 177 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf931b2e..40c8fe84 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ # Build options option(USE_DEBUG "Set to ON for Debug mode" OFF) -option(USE_SANITIZER "Use santizer flags" ON) +option(USE_SANITIZER "Use santizer flags" OFF) option(BUILD_TEST "Build C++ tests with Google Test" ON) option(BUILD_DEBUG_TARGETS "Build Standalone C++ Programs for Debugging" ON) diff --git a/include/stochtree/prior.h b/include/stochtree/prior.h index c2bfc7b2..af095d2f 100644 --- a/include/stochtree/prior.h +++ b/include/stochtree/prior.h @@ -10,62 +10,6 @@ namespace StochTree { -// class LeafGaussianPrior { -// public: -// LeafGaussianPrior() {} -// virtual ~LeafGaussianPrior() = default; -// }; - -// class LeafConstantGaussianPrior : public LeafGaussianPrior { -// public: -// LeafConstantGaussianPrior(double mu_bar, double tau) { -// mu_bar_ = mu_bar; -// tau_ = tau; -// } -// ~LeafConstantGaussianPrior() {} -// double GetPriorMean() {return mu_bar_;} -// double GetPriorScale() {return tau_;} -// void SetPriorMean(double mu_bar) {mu_bar_ = mu_bar;} -// void SetPriorScale(double tau) {tau_ = tau;} -// private: -// double mu_bar_; -// double tau_; -// }; - -// class LeafUnivariateRegressionGaussianPrior : public LeafGaussianPrior { -// public: -// LeafUnivariateRegressionGaussianPrior(double beta_bar, double tau) { -// beta_bar_ = beta_bar; -// tau_ = tau; -// } -// ~LeafUnivariateRegressionGaussianPrior() {} -// double GetPriorMean() {return beta_bar_;} -// double GetPriorScale() {return tau_;} -// void SetPriorMean(double beta_bar) {beta_bar_ = beta_bar;} -// void SetPriorScale(double tau) {tau_ = tau;} -// private: -// double beta_bar_; -// double tau_; -// }; - -// class LeafMultivariateRegressionGaussianPrior : public LeafGaussianPrior { -// public: -// LeafMultivariateRegressionGaussianPrior(Eigen::VectorXd& Beta, Eigen::MatrixXd& Sigma, int basis_dim) { -// Beta_ = Beta; -// Sigma_ = Sigma; -// basis_dim_ = basis_dim; -// } -// ~LeafMultivariateRegressionGaussianPrior() {} -// Eigen::VectorXd GetPriorMean() {return Beta_;} -// Eigen::MatrixXd GetPriorScale() {return Sigma_;} -// void SetPriorMean(Eigen::VectorXd& Beta) {Beta_ = Beta;} -// void SetPriorScale(Eigen::MatrixXd& Sigma) {Sigma_ = Sigma;} -// private: -// Eigen::VectorXd Beta_; -// Eigen::MatrixXd Sigma_; -// int basis_dim_; -// }; - class RandomEffectsGaussianPrior { public: RandomEffectsGaussianPrior() {} @@ -132,126 +76,6 @@ class IGVariancePrior { double scale_; }; -// /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic constant leaf outcome model */ -// struct LeafConstantGaussianSuffStat { -// data_size_t n; -// double sum_y; -// double sum_y_squared; -// LeafConstantGaussianSuffStat() { -// n = 0; -// sum_y = 0.0; -// sum_y_squared = 0.0; -// } -// template -// void IncrementSuffStat(LeafForestDatasetType* data, UnivariateResidual* residual, data_size_t row_idx) { -// n += 1; -// sum_y += residual->residual(row_idx, 0); -// sum_y_squared += std::pow(residual->residual(row_idx, 0), 2.0); -// } -// void ResetSuffStat() { -// n = 0; -// sum_y = 0.0; -// sum_y_squared = 0.0; -// } -// void SubtractSuffStat(LeafConstantGaussianSuffStat& lhs, LeafConstantGaussianSuffStat& rhs) { -// n = lhs.n - rhs.n; -// sum_y = lhs.sum_y - rhs.sum_y; -// sum_y_squared = lhs.sum_y_squared - rhs.sum_y_squared; -// } -// bool SampleGreaterThan(data_size_t threshold) { -// return n > threshold; -// } -// data_size_t SampleSize() { -// return n; -// } -// }; - -// /*! \brief Sufficient statistic and associated operations for homoskedastic, univariate regression leaf outcome model */ -// struct LeafUnivariateRegressionGaussianSuffStat { -// data_size_t n; -// double sum_y; -// double sum_yx; -// double sum_x_squared; -// double sum_y_squared; -// LeafUnivariateRegressionGaussianSuffStat() { -// n = 0; -// sum_y = 0.0; -// sum_yx = 0.0; -// sum_x_squared = 0.0; -// sum_y_squared = 0.0; -// } -// template -// void IncrementSuffStat(RegressionLeafForestDataset* data, UnivariateResidual* residual, data_size_t row_idx) { -// n += 1; -// sum_y += residual->residual(row_idx, 0); -// sum_yx += residual->residual(row_idx, 0)*data->basis(row_idx, 0); -// sum_x_squared += std::pow(data->basis(row_idx, 0), 2.0); -// sum_y_squared += std::pow(residual->residual(row_idx, 0), 2.0); -// } -// void ResetSuffStat() { -// n = 0; -// sum_y = 0.0; -// sum_yx = 0.0; -// sum_x_squared = 0.0; -// sum_y_squared = 0.0; -// } -// void SubtractSuffStat(LeafUnivariateRegressionGaussianSuffStat& lhs, LeafUnivariateRegressionGaussianSuffStat& rhs) { -// n = lhs.n - rhs.n; -// sum_y = lhs.sum_y - rhs.sum_y; -// sum_yx = lhs.sum_yx - rhs.sum_yx; -// sum_x_squared = lhs.sum_x_squared - rhs.sum_x_squared; -// sum_y_squared = lhs.sum_y_squared - rhs.sum_y_squared; -// } -// bool SampleGreaterThan(data_size_t threshold) { -// return n > threshold; -// } -// data_size_t SampleSize() { -// return n; -// } -// }; - -// /*! \brief Sufficient statistic and associated operations for gaussian homoskedastic multivariate regression leaf outcome model */ -// struct LeafMultivariateRegressionGaussianSuffStat { -// data_size_t n; -// int basis_dim; -// double yty; -// Eigen::MatrixXd Xty; -// Eigen::MatrixXd XtX; -// LeafMultivariateRegressionGaussianSuffStat(int basis_dim) { -// basis_dim = basis_dim; -// n = 0; -// yty = 0.0; -// Xty = Eigen::MatrixXd::Zero(basis_dim, 1); -// XtX = Eigen::MatrixXd::Zero(basis_dim, basis_dim); -// } -// template -// void IncrementSuffStat(LeafForestDatasetType* data, UnivariateResidual* residual, data_size_t row_idx) { -// CHECK_EQ(basis_dim, data->basis.cols()); -// n += 1; -// yty += std::pow(residual->residual(row_idx, 0), 2.0); -// Xty += data->basis(row_idx, Eigen::all).transpose()*residual->residual(row_idx, 0); -// XtX += data->basis(row_idx, Eigen::all).transpose()*data->basis(row_idx, Eigen::all); -// } -// void ResetSuffStat() { -// n = 0; -// yty = 0.0; -// Xty = Eigen::MatrixXd::Zero(basis_dim, 1); -// XtX = Eigen::MatrixXd::Zero(basis_dim, basis_dim); -// } -// void SubtractSuffStat(LeafMultivariateRegressionGaussianSuffStat& lhs, LeafMultivariateRegressionGaussianSuffStat& rhs) { -// n = lhs.n - rhs.n; -// yty = lhs.yty - rhs.yty; -// Xty = lhs.Xty - rhs.Xty; -// XtX = lhs.XtX - rhs.XtX; -// } -// bool SampleGreaterThan(data_size_t threshold) { -// return n > threshold; -// } -// data_size_t SampleSize() { -// return n; -// } -// }; - } // namespace StochTree #endif // STOCHTREE_PRIOR_H_ \ No newline at end of file From 71b7aaa05c86cce32757d0c0c11e5b4760c95cfe Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 23 Apr 2024 00:43:52 -0500 Subject: [PATCH 2/8] Added BCF debug program, investigating memory usage --- CMakeLists.txt | 9 +- debug/bcf_debug.cpp | 393 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 400 insertions(+), 2 deletions(-) create mode 100644 debug/bcf_debug.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 40c8fe84..4705794f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,10 +107,15 @@ if(BUILD_TEST) endif() if(BUILD_DEBUG_TARGETS) - # Build test suite - add_executable(debugstochtree debug/api_debug.cpp) + # Debug basic stochtree interface set(StochTree_DEBUG_HEADER_DIR ${PROJECT_SOURCE_DIR}/cpp) + add_executable(debugstochtree debug/api_debug.cpp) target_include_directories(debugstochtree PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR}) target_link_libraries(debugstochtree PRIVATE stochtree_objs) + + # Debug BCF interface + add_executable(debugbcf debug/bcf_debug.cpp) + target_include_directories(debugbcf PRIVATE ${StochTree_HEADER_DIR} ${BOOSTMATH_HEADER_DIR} ${EIGEN_HEADER_DIR} ${StochTree_DEBUG_HEADER_DIR}) + target_link_libraries(debugbcf PRIVATE stochtree_objs) endif() diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp new file mode 100644 index 00000000..612d0495 --- /dev/null +++ b/debug/bcf_debug.cpp @@ -0,0 +1,393 @@ +/*! Copyright (c) 2024 stochtree authors*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace StochTree{ + +enum ForestLeafModel { + kConstant, + kUnivariateRegression, + kMultivariateRegression +}; + +double calibrate_lambda(ForestDataset& covariates, ColumnVector& residual, double nu, double q) { + // Linear model of residual ~ covariates + double n = static_cast(covariates.NumObservations()); + Eigen::MatrixXd X = covariates.GetCovariates(); + Eigen::VectorXd y = residual.GetData(); + Eigen::VectorXd beta = (X.transpose() * X).inverse() * (X.transpose() * y); + double sum_sq_resid = (y - X * beta).transpose() * (y - X * beta); + double sigma_hat = sum_sq_resid / n; + + // Compute implied lambda + return (sigma_hat * boost::math::gamma_q_inv(nu, q)) / nu; +} + +double std_gaussian_cdf(double x) { + return 0.5*(1 + std::erf(x/std::sqrt(2))); +} + +double g(double x1, double x2, double x3, double x4, double x5) { + double output; + if (std::abs(x5-0.0) < 0.001) {output = 2.0;} + else if (std::abs(x5-1.) < 0.001) {output = -1.0;} + else {output = -4.0;} + return output; +} + +double mu1(std::vector& covariates, int n, int x_cols, int i) { + CHECK_GE(x_cols, 5); + CHECK_GT(n, i); + double x1, x2, x3, x4, x5; + x1 = covariates[i*x_cols + 0]; + x2 = covariates[i*x_cols + 1]; + x3 = covariates[i*x_cols + 2]; + x4 = covariates[i*x_cols + 3]; + x5 = covariates[i*x_cols + 4]; + return 1.0 + g(x1,x2,x3,x4,x5) + x1*x3; +} + +double mu2(std::vector& covariates, int n, int x_cols, int i) { + CHECK_GE(x_cols, 5); + CHECK_GT(n, i); + double x1, x2, x3, x4, x5; + x1 = covariates[i*x_cols + 0]; + x2 = covariates[i*x_cols + 1]; + x3 = covariates[i*x_cols + 2]; + x4 = covariates[i*x_cols + 3]; + x5 = covariates[i*x_cols + 4]; + return 1.0 + g(x1,x2,x3,x4,x5) + 6.0*std::abs(x3-1); +} + +double tau1(std::vector& covariates, int n, int x_cols, int i) { + return 3; +} + +double tau2(std::vector& covariates, int n, int x_cols, int i) { + CHECK_GE(x_cols, 5); + CHECK_GT(n, i); + double x1, x2, x3, x4, x5; + x1 = covariates[i*x_cols + 0]; + x2 = covariates[i*x_cols + 1]; + x3 = covariates[i*x_cols + 2]; + x4 = covariates[i*x_cols + 3]; + x5 = covariates[i*x_cols + 4]; + return 1 + 2.0*x2*x4; +} + +void GenerateRandomData(std::vector& covariates, std::vector& propensity, std::vector& treatment, std::vector& outcome, + std::function&, int, int, int)> mu, std::function&, int, int, int)> tau, + int n, int x_cols, double snr = 2.0, int random_seed = -1) { + std::mt19937 gen; + if (random_seed == 1) { + std::random_device rd; + gen = std::mt19937(rd()); + } else { + gen = std::mt19937(random_seed); + } + std::uniform_real_distribution std_uniform_dist{0.0,1.0}; + std::normal_distribution std_normal_dist(0.,1.); + std::discrete_distribution<> binary_covariate_dist({50, 50}); + std::discrete_distribution<> categorical_covariate_dist({33,33,33}); + std::vector mu_x(n); + std::vector tau_x(n); + std::vector f_x_z(n); + + CHECK_GE(x_cols, 5); + double x_val; + for (int i = 0; i < n; i++) { + // Covariates + for (int j = 0; j < x_cols; j++) { + if (j == 3) { + x_val = categorical_covariate_dist(gen); + } else if (j == 4) { + x_val = binary_covariate_dist(gen); + } else { + x_val = std_normal_dist(gen); + } + covariates[i*x_cols + j] = x_val; + } + + // Prognostic function + mu_x[i] = mu(covariates, n, x_cols, i); + + // Treatment effect + tau_x[i] = tau(covariates, n, x_cols, i); + } + + // Compute mean and sd of mu_x + double mu_sum = std::accumulate(mu_x.begin(), mu_x.end(), 0.0); + double mu_mean = mu_sum / static_cast(n); + double mu_sum_squares = std::accumulate(mu_x.begin(), mu_x.end(), 0.0, [](double a, double b){return a + b*b;}); + double mu_stddev = std::sqrt((mu_sum_squares / static_cast(n)) - mu_mean*mu_mean); + + for (int i = 0; i < n; i++) { + // Propensity score + propensity[i] = 0.8*std_gaussian_cdf((3*mu_x[i]/mu_stddev) - 0.5*covariates[i * x_cols + 0]) + 0.05 + std_uniform_dist(gen)/10.; + + // Treatment + treatment[i] = (std_uniform_dist(gen) < propensity[i]) ? 1.0 : 0.0; + + // Expected outcome + f_x_z[i] = mu_x[i] + tau_x[i] * treatment[i]; + } + + // Compute sd(E(Y | X, Z)) + double ey_sum = std::accumulate(f_x_z.begin(), f_x_z.end(), 0.0); + double ey_mean = ey_sum / static_cast(n); + double ey_sum_squares = std::accumulate(f_x_z.begin(), f_x_z.end(), 0.0, [](double a, double b){return a + b*b;}); + double ey_stddev = std::sqrt((ey_sum_squares / static_cast(n)) - ey_mean*ey_mean); + + for (int i = 0; i < n; i++) { + // Propensity score + outcome[i] = f_x_z[i] + (ey_stddev/snr)*std_normal_dist(gen); + } +} + +void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& outcome_scale) { + data_size_t n = residual.NumRows(); + double outcome_val = 0.0; + double outcome_sum = 0.0; + double outcome_sum_squares = 0.0; + double var_y = 0.0; + for (data_size_t i = 0; i < n; i++){ + outcome_val = residual.GetElement(i); + outcome_sum += outcome_val; + outcome_sum_squares += std::pow(outcome_val, 2.0); + } + var_y = outcome_sum_squares / static_cast(n) - std::pow(outcome_sum / static_cast(n), 2.0); + outcome_scale = std::sqrt(var_y); + outcome_offset = outcome_sum / static_cast(n); + double previous_residual; + for (data_size_t i = 0; i < n; i++){ + previous_residual = residual.GetElement(i); + residual.SetElement(i, (previous_residual - outcome_offset) / outcome_scale); + } +} + +void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, + ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, + ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { + if (leaf_model_type == ForestLeafModel::kConstant) { + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); + GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { + GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); + GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { + GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); + GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); + } +} + +void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, + ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, + ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { + if (leaf_model_type == ForestLeafModel::kConstant) { + GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); + MCMCForestSampler sampler = MCMCForestSampler(); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { + GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); + MCMCForestSampler sampler = MCMCForestSampler(); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { + GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); + MCMCForestSampler sampler = MCMCForestSampler(); + sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); + } +} + +void RunAPI() { + // Data dimensions + int n = 1000; + int x_cols = 5; + + // Declare covariates, basis and outcome + std::vector covariates_raw(n*x_cols); + std::vector propensity_raw(n); + std::vector treatment_raw(n); + std::vector outcome_raw(n); + + // Load the data + GenerateRandomData(covariates_raw, propensity_raw, treatment_raw, outcome_raw, mu1, tau2, n, x_cols, 2.0, 1); + + // Add pi_x as a column to a covariate set used in the prognostic forest + std::vector covariates_pi(n*(x_cols+1)); + for (int i = 0; i < n; i++) { + for (int j = 0; j < x_cols; j++) { + covariates_pi[i*(x_cols+1) + j] = covariates_raw[i*x_cols + j]; + } + covariates_pi[i*(x_cols+1) + x_cols] = propensity_raw[i]; + } + + // Define internal datasets + bool row_major = true; + + // Construct datasets for training, include pi(x) as a covariate in the prognostic forest + ForestDataset tau_dataset = ForestDataset(); + tau_dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); + tau_dataset.AddBasis(treatment_raw.data(), n, 1, row_major); + ForestDataset mu_dataset = ForestDataset(); + mu_dataset.AddCovariates(covariates_pi.data(), n, x_cols+1, row_major); + ColumnVector residual = ColumnVector(outcome_raw.data(), n); + + // Center and scale the data + double outcome_offset; + double outcome_scale; + OutcomeOffsetScale(residual, outcome_offset, outcome_scale); + + // Initialize ensembles for prognostic and treatment forests + int num_trees_mu = 200; + int num_trees_tau = 50; + ForestContainer forest_samples_mu = ForestContainer(num_trees_mu, 1, true); + ForestContainer forest_samples_tau = ForestContainer(num_trees_tau, 1, false); + + // Initialize leaf models for mu and tau forests + double leaf_prior_scale_mu = (outcome_scale*outcome_scale)/num_trees_mu; + double leaf_prior_scale_tau = (outcome_scale*outcome_scale)/(2*num_trees_tau); + GaussianConstantLeafModel leaf_model_mu = GaussianConstantLeafModel(leaf_prior_scale_mu); + GaussianUnivariateRegressionLeafModel leaf_model_tau = GaussianUnivariateRegressionLeafModel(leaf_prior_scale_tau); + + // Initialize forest sampling machinery + std::vector feature_types_mu(x_cols + 1, FeatureType::kNumeric); + feature_types_mu[3] = FeatureType::kOrderedCategorical; + feature_types_mu[4] = FeatureType::kOrderedCategorical; + std::vector feature_types_tau(x_cols, FeatureType::kNumeric); + feature_types_tau[3] = FeatureType::kOrderedCategorical; + feature_types_tau[4] = FeatureType::kOrderedCategorical; + double alpha_mu = 0.95; + double alpha_tau = 0.25; + double beta_mu = 2.0; + double beta_tau = 3.0; + int min_samples_leaf_mu = 5; + int min_samples_leaf_tau = 5; + int cutpoint_grid_size_mu = 100; + int cutpoint_grid_size_tau = 100; + double a_leaf_mu = 3.; + double b_leaf_mu = leaf_prior_scale_mu; + double a_leaf_tau = 3.; + double b_leaf_tau = leaf_prior_scale_tau; + double nu = 3.; + double lamb = calibrate_lambda(tau_dataset, residual, nu, 0.9); + ForestLeafModel leaf_model_type_mu = ForestLeafModel::kConstant; + ForestLeafModel leaf_model_type_tau = ForestLeafModel::kUnivariateRegression; + + // Set leaf model parameters + double leaf_scale_mu; + double leaf_scale_tau = leaf_prior_scale_tau; + Eigen::MatrixXd leaf_scale_matrix_mu; + Eigen::MatrixXd leaf_scale_matrix_tau; + + // Set global variance + double global_variance_init = 1.0; + double global_variance; + + // Set variable weights + double const_var_wt_mu = static_cast(1/(x_cols+1)); + std::vector variable_weights_mu(x_cols+1, const_var_wt_mu); + double const_var_wt_tau = static_cast(1/x_cols); + std::vector variable_weights_tau(x_cols, const_var_wt_tau); + + // Initialize tracker and tree prior + ForestTracker mu_tracker = ForestTracker(mu_dataset.GetCovariates(), feature_types_mu, num_trees_mu, n); + ForestTracker tau_tracker = ForestTracker(tau_dataset.GetCovariates(), feature_types_tau, num_trees_tau, n); + TreePrior tree_prior_mu = TreePrior(alpha_mu, beta_mu, min_samples_leaf_mu); + TreePrior tree_prior_tau = TreePrior(alpha_tau, beta_tau, min_samples_leaf_tau); + + // Initialize a random number generator + std::random_device rd; + std::mt19937 rng = std::mt19937(rd()); + + // Initialize variance models + GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); + LeafNodeHomoskedasticVarianceModel leaf_var_model_mu = LeafNodeHomoskedasticVarianceModel(); + + // Initialize storage for samples of variance + std::vector global_variance_samples{}; + std::vector leaf_variance_samples_mu{}; + + // Run the GFR sampler + int num_gfr_samples = 10; + for (int i = 0; i < num_gfr_samples; i++) { + if (i == 0) { + global_variance = global_variance_init; + leaf_scale_mu = leaf_prior_scale_mu; + } else { + global_variance = global_variance_samples[i-1]; + leaf_scale_mu = leaf_variance_samples_mu[i-1]; + } + + // Sample mu ensemble + sampleGFR(mu_tracker, tree_prior_mu, forest_samples_mu, mu_dataset, residual, rng, feature_types_mu, variable_weights_mu, + leaf_model_type_mu, leaf_scale_matrix_mu, global_variance, leaf_scale_mu, cutpoint_grid_size_mu); + + // Sample leaf node variance + leaf_variance_samples_mu.push_back(leaf_var_model_mu.SampleVarianceParameter(forest_samples_mu.GetEnsemble(i), a_leaf_mu, b_leaf_mu, rng)); + + // Sample global variance + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); + + // Sample tau ensemble + sampleGFR(tau_tracker, tree_prior_tau, forest_samples_tau, tau_dataset, residual, rng, feature_types_tau, variable_weights_tau, + leaf_model_type_tau, leaf_scale_matrix_tau, global_variance, leaf_scale_tau, cutpoint_grid_size_tau); + + // Sample global variance + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); + } + + // Run the MCMC sampler + int num_mcmc_samples = 10000; + for (int i = num_gfr_samples; i < num_gfr_samples + num_mcmc_samples; i++) { + if (i == 0) { + global_variance = global_variance_init; + leaf_scale_mu = leaf_prior_scale_mu; + } else { + global_variance = global_variance_samples[i-1]; + leaf_scale_mu = leaf_variance_samples_mu[i-1]; + } + + // Sample mu ensemble + sampleMCMC(mu_tracker, tree_prior_mu, forest_samples_mu, mu_dataset, residual, rng, feature_types_mu, variable_weights_mu, + leaf_model_type_mu, leaf_scale_matrix_mu, global_variance, leaf_scale_mu, cutpoint_grid_size_mu); + + // Sample leaf node variance + leaf_variance_samples_mu.push_back(leaf_var_model_mu.SampleVarianceParameter(forest_samples_mu.GetEnsemble(i), a_leaf_mu, b_leaf_mu, rng)); + + // Sample global variance + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); + + // Sample tau ensemble + sampleMCMC(tau_tracker, tree_prior_tau, forest_samples_tau, tau_dataset, residual, rng, feature_types_tau, variable_weights_tau, + leaf_model_type_tau, leaf_scale_matrix_tau, global_variance, leaf_scale_tau, cutpoint_grid_size_tau); + + // Sample global variance + global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); + } +} + +} // namespace StochTree + +int main() { + StochTree::RunAPI(); +} From 4ff1e72f3ba694642b4badc5e3d36b634858cb5e Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Sun, 28 Apr 2024 23:49:14 -0500 Subject: [PATCH 3/8] Updated to use Eigen::Map in the ForestDataset class (thus allowing zero-copy reuse of R / Python data) --- debug/api_debug.cpp | 6 +- debug/bcf_debug.cpp | 7 +- include/stochtree/cutpoint_candidates.h | 24 +-- include/stochtree/data.h | 48 +++-- include/stochtree/ensemble.h | 10 +- include/stochtree/meta.h | 9 + include/stochtree/partition_tracker.h | 45 ++-- include/stochtree/random_effects.h | 2 +- include/stochtree/tree.h | 10 +- src/cutpoint_candidates.cpp | 12 +- src/data.cpp | 55 +---- src/leaf_model.cpp | 6 +- src/partition_tracker.cpp | 21 +- src/tree.cpp | 5 +- test/testutils.cpp | 273 +++++++++++------------- test/testutils.h | 9 +- 16 files changed, 245 insertions(+), 297 deletions(-) diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index ec383936..ba335085 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -38,15 +38,15 @@ void GenerateRandomData(std::vector& covariates, std::vector& ba for (int i = 0; i < n; i++) { for (int j = 0; j < x_cols; j++) { - covariates[i*x_cols + j] = uniform_dist(gen); + covariates[j*n + i] = uniform_dist(gen); } for (int j = 0; j < omega_cols; j++) { - basis[i*omega_cols + j] = uniform_dist(gen); + basis[j*n + i] = uniform_dist(gen); } for (int j = 0; j < rfx_basis_cols; j++) { - rfx_basis[i*rfx_basis_cols + j] = 1; + rfx_basis[j*n + i] = 1; } if (i % 2 == 0) { diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 612d0495..1ef35f2d 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -122,7 +122,8 @@ void GenerateRandomData(std::vector& covariates, std::vector& pr } else { x_val = std_normal_dist(gen); } - covariates[i*x_cols + j] = x_val; + // Store in column-major format + covariates[j*n + i] = x_val; } // Prognostic function @@ -242,7 +243,7 @@ void RunAPI() { } // Define internal datasets - bool row_major = true; + bool row_major = false; // Construct datasets for training, include pi(x) as a covariate in the prognostic forest ForestDataset tau_dataset = ForestDataset(); @@ -383,6 +384,8 @@ void RunAPI() { // Sample global variance global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); + + // Estimatw } } diff --git a/include/stochtree/cutpoint_candidates.h b/include/stochtree/cutpoint_candidates.h index 8c19013a..098ffa74 100644 --- a/include/stochtree/cutpoint_candidates.h +++ b/include/stochtree/cutpoint_candidates.h @@ -56,16 +56,16 @@ class FeatureCutpointGrid { ~FeatureCutpointGrid() {} /*! \brief Calculate strides */ - void CalculateStrides(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types); + void CalculateStrides(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types); /*! \brief Split numeric / ordered categorical feature and update sort indices */ - void CalculateStridesNumeric(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStridesNumeric(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Split numeric / ordered categorical feature and update sort indices */ - void CalculateStridesOrderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStridesOrderedCategorical(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Split unordered categorical feature and update sort indices */ - void CalculateStridesUnorderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStridesUnorderedCategorical(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Number of potential cutpoints enumerated */ int32_t NumCutpoints() {return node_stride_begin_.size();} @@ -102,16 +102,16 @@ class FeatureCutpointGrid { int32_t cutpoint_grid_size_; /*! \brief Full enumeration of numeric cutpoints, checking for duplicate value */ - void EnumerateNumericCutpointsDeduplication(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index); + void EnumerateNumericCutpointsDeduplication(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index); /*! \brief Calculation of numeric cutpoints, thinning out to ensure that, at most, cutpoint_grid_size_ cutpoints are considered */ - void ScanNumericCutpoints(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index); + void ScanNumericCutpoints(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index); }; /*! \brief Container class for FeatureCutpointGrid objects stored for every feature in a dataset */ class CutpointGridContainer { public: - CutpointGridContainer(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, int cutpoint_grid_size) { + CutpointGridContainer(MatrixMap& covariates, Eigen::VectorXd& residuals, int cutpoint_grid_size) { num_features_ = covariates.cols(); feature_cutpoint_grid_.resize(num_features_); for (int i = 0; i < num_features_; i++) { @@ -122,7 +122,7 @@ class CutpointGridContainer { ~CutpointGridContainer() {} - void Reset(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, int cutpoint_grid_size) { + void Reset(MatrixMap& covariates, Eigen::VectorXd& residuals, int cutpoint_grid_size) { num_features_ = covariates.cols(); feature_cutpoint_grid_.resize(num_features_); for (int i = 0; i < num_features_; i++) { @@ -132,7 +132,7 @@ class CutpointGridContainer { } /*! \brief Calculate strides */ - void CalculateStrides(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types) { + void CalculateStrides(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types) { feature_cutpoint_grid_[feature_index]->CalculateStrides(covariates, residuals, feature_node_sort_tracker, node_id, node_begin, node_end, feature_index, feature_types); } @@ -177,13 +177,13 @@ class NodeCutpointTracker { ~NodeCutpointTracker() {} /*! \brief Calculate strides */ - void CalculateStrides(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStrides(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Split numeric / ordered categorical feature and update sort indices */ - void CalculateStridesNumeric(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStridesNumeric(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Split unordered categorical feature and update sort indices */ - void CalculateStridesCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, data_size_t node_begin, data_size_t node_end, int32_t feature_index); + void CalculateStridesCategorical(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, data_size_t node_begin, data_size_t node_end, int32_t feature_index); /*! \brief Number of potential cutpoints enumerated */ int32_t NumCutpoints() {return node_stride_begin_.size();} diff --git a/include/stochtree/data.h b/include/stochtree/data.h index 129a5a51..e8abe51f 100644 --- a/include/stochtree/data.h +++ b/include/stochtree/data.h @@ -14,7 +14,7 @@ namespace StochTree { class ColumnMatrix { public: - ColumnMatrix() {} + ColumnMatrix(); ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); ~ColumnMatrix() {} double GetElement(data_size_t row_num, int32_t col_num) {return data_(row_num, col_num);} @@ -22,9 +22,9 @@ class ColumnMatrix { void LoadData(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major); inline data_size_t NumRows() {return data_.rows();} inline int NumCols() {return data_.cols();} - inline Eigen::MatrixXd& GetData() {return data_;} + inline MatrixMap& GetData() {return data_;} private: - Eigen::MatrixXd data_; + MatrixMap data_; }; class ColumnVector { @@ -46,13 +46,15 @@ class ForestDataset { ForestDataset() {} ~ForestDataset() {} void AddCovariates(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { - covariates_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); +// covariates_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); + covariates_.LoadData(data_ptr, num_row, num_col, is_row_major); num_observations_ = num_row; num_covariates_ = num_col; has_covariates_ = true; } void AddBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { - basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); +// basis_ = ColumnMatrix(data_ptr, num_row, num_col, is_row_major); + basis_.LoadData(data_ptr, num_row, num_col, is_row_major); num_basis_ = num_col; has_basis_ = true; } @@ -69,26 +71,28 @@ class ForestDataset { inline double CovariateValue(data_size_t row, int col) {return covariates_.GetElement(row, col);} inline double BasisValue(data_size_t row, int col) {return basis_.GetElement(row, col);} inline double VarWeightValue(data_size_t row) {return var_weights_.GetElement(row);} - inline Eigen::MatrixXd& GetCovariates() {return covariates_.GetData();} - inline Eigen::MatrixXd& GetBasis() {return basis_.GetData();} + inline MatrixMap& GetCovariates() {return covariates_.GetData();} + inline MatrixMap& GetBasis() {return basis_.GetData();} inline Eigen::VectorXd& GetVarWeights() {return var_weights_.GetData();} void UpdateBasis(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { CHECK(has_basis_); CHECK_EQ(num_col, num_basis_); - // Copy data from R / Python process memory to Eigen matrix - double temp_value; - for (data_size_t i = 0; i < num_row; ++i) { - for (int j = 0; j < num_col; ++j) { - if (is_row_major){ - // Numpy 2-d arrays are stored in "row major" order - temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); - } else { - // R matrices are stored in "column major" order - temp_value = static_cast(*(data_ptr + static_cast(num_row) * j + i)); - } - basis_.SetElement(i, j, temp_value); - } - } + // Update the map + basis_.LoadData(data_ptr, num_row, num_col, is_row_major); +// // Copy data from R / Python process memory to Eigen matrix +// double temp_value; +// for (data_size_t i = 0; i < num_row; ++i) { +// for (int j = 0; j < num_col; ++j) { +// if (is_row_major){ +// // Numpy 2-d arrays are stored in "row major" order +// temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); +// } else { +// // R matrices are stored in "column major" order +// temp_value = static_cast(*(data_ptr + static_cast(num_row) * j + i)); +// } +// basis_.SetElement(i, j, temp_value); +// } +// } } private: ColumnMatrix covariates_; @@ -124,7 +128,7 @@ class RandomEffectsDataset { inline double BasisValue(data_size_t row, int col) {return basis_.GetElement(row, col);} inline double VarWeightValue(data_size_t row) {return var_weights_.GetElement(row);} inline int32_t GroupId(data_size_t row) {return group_labels_[row];} - inline Eigen::MatrixXd& GetBasis() {return basis_.GetData();} + inline MatrixMap& GetBasis() {return basis_.GetData();} inline Eigen::VectorXd& GetVarWeights() {return var_weights_.GetData();} inline std::vector& GetGroupLabels() {return group_labels_;} private: diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index 162bfc74..afbd4c79 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -90,11 +90,11 @@ class TreeEnsemble { } } - inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector &output, data_size_t offset = 0) { + inline void PredictInplace(MatrixMap& covariates, MatrixMap& basis, std::vector &output, data_size_t offset = 0) { PredictInplace(covariates, basis, output, 0, trees_.size(), offset); } - inline void PredictInplace(Eigen::MatrixXd& covariates, Eigen::MatrixXd& basis, std::vector &output, + inline void PredictInplace(MatrixMap& covariates, MatrixMap& basis, std::vector &output, int tree_begin, int tree_end, data_size_t offset = 0) { double pred; CHECK_EQ(covariates.rows(), basis.rows()); @@ -118,11 +118,11 @@ class TreeEnsemble { } } - inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector &output, data_size_t offset = 0) { + inline void PredictInplace(MatrixMap& covariates, std::vector &output, data_size_t offset = 0) { PredictInplace(covariates, output, 0, trees_.size(), offset); } - inline void PredictInplace(Eigen::MatrixXd& covariates, std::vector &output, int tree_begin, int tree_end, data_size_t offset = 0) { + inline void PredictInplace(MatrixMap& covariates, std::vector &output, int tree_begin, int tree_end, data_size_t offset = 0) { double pred; data_size_t n = covariates.rows(); data_size_t total_output_size = n; @@ -147,7 +147,7 @@ class TreeEnsemble { inline void PredictRawInplace(ForestDataset& dataset, std::vector &output, int tree_begin, int tree_end, data_size_t offset = 0) { double pred; - Eigen::MatrixXd covariates = dataset.GetCovariates(); + MatrixMap covariates = dataset.GetCovariates(); CHECK_EQ(output_dimension_, trees_[0]->OutputDimension()); data_size_t n = covariates.rows(); data_size_t total_output_size = n * output_dimension_; diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index b77179ec..5ae002e2 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -18,6 +18,7 @@ #include #include #include +#include #if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER) || MM_PREFETCH #include @@ -63,6 +64,14 @@ enum RandomEffectsType { kRegressionRandomEffect }; +/*! \brief Eigen Map objects that expose matrix / vector operations directly on raw buffers without copying data */ +typedef Eigen::Matrix MatrixObject; +typedef Eigen::Matrix VectorObject; + +/*! \brief Eigen Map objects that expose matrix / vector operations directly on raw buffers without copying data */ +typedef Eigen::Map> MatrixMap; +typedef Eigen::Map> VectorMap; + /*! \brief Type of data size */ typedef int32_t data_size_t; diff --git a/include/stochtree/partition_tracker.h b/include/stochtree/partition_tracker.h index 6025496c..2e949722 100644 --- a/include/stochtree/partition_tracker.h +++ b/include/stochtree/partition_tracker.h @@ -49,15 +49,15 @@ class FeaturePresortRootContainer; /*! \brief Wrapper around various data structures for forest sampling algorithms */ class ForestTracker { public: - ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int num_trees, int num_observations); + ForestTracker(MatrixMap& covariates, std::vector& feature_types, int num_trees, int num_observations); ~ForestTracker() {} void AssignAllSamplesToRoot(); void AssignAllSamplesToRoot(int32_t tree_num); void AssignAllSamplesToConstantPrediction(double value); void AssignAllSamplesToConstantPrediction(int32_t tree_num, double value); - void ResetRoot(Eigen::MatrixXd& covariates, std::vector& feature_types, int32_t tree_num); - void AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false); - void RemoveSplit(Eigen::MatrixXd& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false); + void ResetRoot(MatrixMap& covariates, std::vector& feature_types, int32_t tree_num); + void AddSplit(MatrixMap& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false); + void RemoveSplit(MatrixMap& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted = false); double GetTreeSamplePrediction(data_size_t sample_id, int tree_id); void SetTreeSamplePrediction(data_size_t sample_id, int tree_id, double value); data_size_t GetNodeId(int observation_num, int tree_num); @@ -163,9 +163,8 @@ class SampleNodeMapper { } } - void AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id) { + void AddSplit(MatrixMap& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id) { CHECK_EQ(num_observations_, covariates.rows()); - // Eigen::MatrixXd X = covariates.GetData(); for (int i = 0; i < num_observations_; i++) { if (tree_observation_indices_[tree_id][i] == split_node_id) { auto fvalue = covariates(i, split_feature); @@ -212,13 +211,13 @@ class FeatureUnsortedPartition { FeatureUnsortedPartition(data_size_t n); /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split); + void PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split); /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value); + void PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value); /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list); + void PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list); /*! \brief Convert a (currently split) node to a leaf */ void PruneNodeToLeaf(int node_id); @@ -289,17 +288,17 @@ class UnsortedNodeSampleTracker { } /*! \brief Partition a node based on a new split rule */ - void PartitionTreeNode(Eigen::MatrixXd& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split) { + void PartitionTreeNode(MatrixMap& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split) { return feature_partitions_[tree_id]->PartitionNode(covariates, node_id, left_node_id, right_node_id, feature_split, split); } /*! \brief Partition a node based on a new split rule */ - void PartitionTreeNode(Eigen::MatrixXd& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value) { + void PartitionTreeNode(MatrixMap& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value) { return feature_partitions_[tree_id]->PartitionNode(covariates, node_id, left_node_id, right_node_id, feature_split, split_value); } /*! \brief Partition a node based on a new split rule */ - void PartitionTreeNode(Eigen::MatrixXd& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list) { + void PartitionTreeNode(MatrixMap& covariates, int tree_id, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list) { return feature_partitions_[tree_id]->PartitionNode(covariates, node_id, left_node_id, right_node_id, feature_split, category_list); } @@ -448,14 +447,14 @@ class FeaturePresortPartition; class FeaturePresortRoot { friend FeaturePresortPartition; public: - FeaturePresortRoot(Eigen::MatrixXd& covariates, int32_t feature_index, FeatureType feature_type) { + FeaturePresortRoot(MatrixMap& covariates, int32_t feature_index, FeatureType feature_type) { feature_index_ = feature_index; ArgsortRoot(covariates); } ~FeaturePresortRoot() {} - void ArgsortRoot(Eigen::MatrixXd& covariates) { + void ArgsortRoot(MatrixMap& covariates) { data_size_t num_obs = covariates.rows(); // Make a vector of indices from 0 to num_obs - 1 @@ -479,7 +478,7 @@ class FeaturePresortRoot { /*! \brief Container class for FeaturePresortRoot objects stored for every feature in a dataset */ class FeaturePresortRootContainer { public: - FeaturePresortRootContainer(Eigen::MatrixXd& covariates, std::vector& feature_types) { + FeaturePresortRootContainer(MatrixMap& covariates, std::vector& feature_types) { num_features_ = covariates.cols(); feature_presort_.resize(num_features_); for (int i = 0; i < num_features_; i++) { @@ -508,7 +507,7 @@ class FeaturePresortRootContainer { */ class FeaturePresortPartition { public: - FeaturePresortPartition(FeaturePresortRoot* feature_presort_root, Eigen::MatrixXd& covariates, int32_t feature_index, FeatureType feature_type) { + FeaturePresortPartition(FeaturePresortRoot* feature_presort_root, MatrixMap& covariates, int32_t feature_index, FeatureType feature_type) { // Unpack all feature details feature_index_ = feature_index; feature_type_ = feature_type; @@ -523,13 +522,13 @@ class FeaturePresortPartition { ~FeaturePresortPartition() {} /*! \brief Split numeric / ordered categorical feature and update sort indices */ - void SplitFeature(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, TreeSplit& split); + void SplitFeature(MatrixMap& covariates, int32_t node_id, int32_t feature_index, TreeSplit& split); /*! \brief Split numeric / ordered categorical feature and update sort indices */ - void SplitFeatureNumeric(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, double split_value); + void SplitFeatureNumeric(MatrixMap& covariates, int32_t node_id, int32_t feature_index, double split_value); /*! \brief Split unordered categorical feature and update sort indices */ - void SplitFeatureCategorical(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, std::vector const& category_list); + void SplitFeatureCategorical(MatrixMap& covariates, int32_t node_id, int32_t feature_index, std::vector const& category_list); /*! \brief Start position of node indexed by node_id */ data_size_t NodeBegin(int32_t node_id) {return node_offset_sizes_[node_id].Begin();} @@ -568,7 +567,7 @@ class FeaturePresortPartition { /*! \brief Data structure for tracking observations through a tree partition with each feature pre-sorted */ class SortedNodeSampleTracker { public: - SortedNodeSampleTracker(FeaturePresortRootContainer* feature_presort_root_container, Eigen::MatrixXd& covariates, std::vector& feature_types) { + SortedNodeSampleTracker(FeaturePresortRootContainer* feature_presort_root_container, MatrixMap& covariates, std::vector& feature_types) { num_features_ = covariates.cols(); feature_partitions_.resize(num_features_); FeaturePresortRoot* feature_presort_root; @@ -579,21 +578,21 @@ class SortedNodeSampleTracker { } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, TreeSplit& split) { + void PartitionNode(MatrixMap& covariates, int node_id, int feature_split, TreeSplit& split) { for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeature(covariates, node_id, feature_split, split); } } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, double split_value) { + void PartitionNode(MatrixMap& covariates, int node_id, int feature_split, double split_value) { for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeatureNumeric(covariates, node_id, feature_split, split_value); } } /*! \brief Partition a node based on a new split rule */ - void PartitionNode(Eigen::MatrixXd& covariates, int node_id, int feature_split, std::vector const& category_list) { + void PartitionNode(MatrixMap& covariates, int node_id, int feature_split, std::vector const& category_list) { for (int i = 0; i < num_features_; i++) { feature_partitions_[i]->SplitFeatureCategorical(covariates, node_id, feature_split, category_list); } diff --git a/include/stochtree/random_effects.h b/include/stochtree/random_effects.h index 6f145101..12b0e6a6 100644 --- a/include/stochtree/random_effects.h +++ b/include/stochtree/random_effects.h @@ -156,7 +156,7 @@ class RandomEffectsTerm { inline int32_t NumGroups() {return num_groups_;} std::vector Predict(RandomEffectsDataset& rfx_dataset) { - Eigen::MatrixXd X = rfx_dataset.GetBasis(); + MatrixMap X = rfx_dataset.GetBasis(); std::vector group_labels = rfx_dataset.GetGroupLabels(); CHECK_EQ(X.rows(), group_labels.size()); int n = X.rows(); diff --git a/include/stochtree/tree.h b/include/stochtree/tree.h index 2dc7126d..c4ce832e 100644 --- a/include/stochtree/tree.h +++ b/include/stochtree/tree.h @@ -224,9 +224,9 @@ class Tree { */ void InplacePredictFromNodes(std::vector result, std::vector node_indices); std::vector PredictFromNodes(std::vector node_indices); - std::vector PredictFromNodes(std::vector node_indices, Eigen::MatrixXd& basis); + std::vector PredictFromNodes(std::vector node_indices, MatrixMap& basis); double PredictFromNode(std::int32_t node_id); - double PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int row_idx); + double PredictFromNode(std::int32_t node_id, MatrixMap& basis, int row_idx); /** Getters **/ /*! @@ -697,7 +697,7 @@ inline int NextNodeCategorical(double fvalue, std::vector const& * \param data Dataset used for prediction * \param row Row indexing the prediction observation */ -inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) { +inline int EvaluateTree(Tree const& tree, MatrixMap& data, int row) { int node_id = 0; while (!tree.IsLeaf(node_id)) { auto const split_index = tree.SplitIndex(node_id); @@ -722,7 +722,7 @@ inline int EvaluateTree(Tree const& tree, Eigen::MatrixXd& data, int row) { * \param split_index Column of new split * \param split_value Value defining the split */ -inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, double split_value) { +inline bool RowSplitLeft(MatrixMap& covariates, int row, int split_index, double split_value) { double const fvalue = covariates(row, split_index); return SplitTrueNumeric(fvalue, split_value); } @@ -733,7 +733,7 @@ inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, * \param split_index Column of new split * \param category_list Categories defining the split */ -inline bool RowSplitLeft(Eigen::MatrixXd& covariates, int row, int split_index, std::vector const& category_list) { +inline bool RowSplitLeft(MatrixMap& covariates, int row, int split_index, std::vector const& category_list) { double const fvalue = covariates(row, split_index); return SplitTrueCategorical(fvalue, category_list); } diff --git a/src/cutpoint_candidates.cpp b/src/cutpoint_candidates.cpp index 4a0845c7..c686f176 100644 --- a/src/cutpoint_candidates.cpp +++ b/src/cutpoint_candidates.cpp @@ -6,7 +6,7 @@ namespace StochTree { -void FeatureCutpointGrid::CalculateStrides(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types) { +void FeatureCutpointGrid::CalculateStrides(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index, std::vector& feature_types) { // Reset the stride vectors node_stride_begin_.clear(); node_stride_length_.clear(); @@ -23,7 +23,7 @@ void FeatureCutpointGrid::CalculateStrides(Eigen::MatrixXd& covariates, Eigen::V } } -void FeatureCutpointGrid::CalculateStridesNumeric(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { +void FeatureCutpointGrid::CalculateStridesNumeric(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { data_size_t node_size = node_end - node_begin; // Check if node has fewer observations than cutpoint_grid_size if (node_size <= cutpoint_grid_size_) { @@ -41,7 +41,7 @@ void FeatureCutpointGrid::CalculateStridesNumeric(Eigen::MatrixXd& covariates, E } } -void FeatureCutpointGrid::CalculateStridesOrderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { +void FeatureCutpointGrid::CalculateStridesOrderedCategorical(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { data_size_t node_size = node_end - node_begin; // Edge case 1: single observation @@ -103,7 +103,7 @@ void FeatureCutpointGrid::CalculateStridesOrderedCategorical(Eigen::MatrixXd& co } } -void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { +void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, int32_t feature_index) { // TODO: refactor so that this initial code is shared between ordered and unordered categorical cutpoint calculation data_size_t node_size = node_end - node_begin; std::vector bin_sums; @@ -199,7 +199,7 @@ void FeatureCutpointGrid::CalculateStridesUnorderedCategorical(Eigen::MatrixXd& } } -void FeatureCutpointGrid::EnumerateNumericCutpointsDeduplication(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index) { +void FeatureCutpointGrid::EnumerateNumericCutpointsDeduplication(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index) { // Edge case 1: single observation double single_value; if (node_end - node_begin == 1) { @@ -258,7 +258,7 @@ void FeatureCutpointGrid::EnumerateNumericCutpointsDeduplication(Eigen::MatrixXd } } -void FeatureCutpointGrid::ScanNumericCutpoints(Eigen::MatrixXd& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index) { +void FeatureCutpointGrid::ScanNumericCutpoints(MatrixMap& covariates, Eigen::VectorXd& residuals, SortedNodeSampleTracker* feature_node_sort_tracker, int32_t node_id, data_size_t node_begin, data_size_t node_end, data_size_t node_size, int32_t feature_index) { // Edge case 1: single observation double single_value; if (node_end - node_begin == 1) { diff --git a/src/data.cpp b/src/data.cpp index ea667bce..6f4dc344 100644 --- a/src/data.cpp +++ b/src/data.cpp @@ -5,27 +5,18 @@ namespace StochTree { -ColumnMatrix::ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { +ColumnMatrix::ColumnMatrix() : data_(NULL,1,1) { + // Eigen::Map does not have a default initializer, + // so we initialize the data member with a null pointer + // and modify later +} + +ColumnMatrix::ColumnMatrix(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) : data_(NULL,1,1) { LoadData(data_ptr, num_row, num_col, is_row_major); } void ColumnMatrix::LoadData(double* data_ptr, data_size_t num_row, int num_col, bool is_row_major) { - data_.resize(num_row, num_col); - - // Copy data from R / Python process memory to Eigen matrix - double temp_value; - for (data_size_t i = 0; i < num_row; ++i) { - for (int j = 0; j < num_col; ++j) { - if (is_row_major){ - // Numpy 2-d arrays are stored in "row major" order - temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); - } else { - // R matrices are stored in "column major" order - temp_value = static_cast(*(data_ptr + static_cast(num_row) * j + i)); - } - data_(i, j) = temp_value; - } - } + new (&data_) MatrixMap(data_ptr,num_row,num_col); } ColumnVector::ColumnVector(double* data_ptr, data_size_t num_row) { @@ -43,34 +34,4 @@ void ColumnVector::LoadData(double* data_ptr, data_size_t num_row) { } } -void LoadData(double* data_ptr, int num_row, int num_col, bool is_row_major, Eigen::MatrixXd& data_matrix) { - data_matrix.resize(num_row, num_col); - - // Copy data from R / Python process memory to Eigen matrix - double temp_value; - for (data_size_t i = 0; i < num_row; ++i) { - for (int j = 0; j < num_col; ++j) { - if (is_row_major){ - // Numpy 2-d arrays are stored in "row major" order - temp_value = static_cast(*(data_ptr + static_cast(num_col) * i + j)); - } else { - // R matrices are stored in "column major" order - temp_value = static_cast(*(data_ptr + static_cast(num_row) * j + i)); - } - data_matrix(i, j) = temp_value; - } - } -} - -void LoadData(double* data_ptr, int num_row, Eigen::VectorXd& data_vector) { - data_vector.resize(num_row); - - // Copy data from R / Python process memory to Eigen matrix - double temp_value; - for (data_size_t i = 0; i < num_row; ++i) { - temp_value = static_cast(*(data_ptr + i)); - data_vector(i) = temp_value; - } -} - } // namespace StochTree diff --git a/src/leaf_model.cpp b/src/leaf_model.cpp index 3b522855..ad7922b7 100644 --- a/src/leaf_model.cpp +++ b/src/leaf_model.cpp @@ -150,7 +150,7 @@ void GaussianConstantLeafModel::EvaluateAllPossibleSplits(ForestDataset& dataset double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); + MatrixMap covariates = dataset.GetCovariates(); Eigen::VectorXd outcome = residual.GetData(); Eigen::VectorXd var_weights; bool has_weights = dataset.HasVarWeights(); @@ -346,7 +346,7 @@ void GaussianUnivariateRegressionLeafModel::EvaluateAllPossibleSplits(ForestData double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); + MatrixMap covariates = dataset.GetCovariates(); Eigen::VectorXd outcome = residual.GetData(); Eigen::VectorXd var_weights; bool has_weights = dataset.HasVarWeights(); @@ -545,7 +545,7 @@ void GaussianMultivariateRegressionLeafModel::EvaluateAllPossibleSplits(ForestDa double no_split_log_ml = NoSplitLogMarginalLikelihood(node_suff_stat, global_variance); // Unpack data - Eigen::MatrixXd covariates = dataset.GetCovariates(); + MatrixMap covariates = dataset.GetCovariates(); Eigen::VectorXd outcome = residual.GetData(); Eigen::VectorXd var_weights; bool has_weights = dataset.HasVarWeights(); diff --git a/src/partition_tracker.cpp b/src/partition_tracker.cpp index 1b0caf28..3c6bda5e 100644 --- a/src/partition_tracker.cpp +++ b/src/partition_tracker.cpp @@ -4,6 +4,7 @@ * license information. */ #include +#include #include #include @@ -14,7 +15,7 @@ namespace StochTree { -ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int num_trees, int num_observations) { +ForestTracker::ForestTracker(MatrixMap& covariates, std::vector& feature_types, int num_trees, int num_observations) { sample_pred_mapper_ = std::make_unique(num_trees, num_observations); sample_node_mapper_ = std::make_unique(num_trees, num_observations); unsorted_node_sample_tracker_ = std::make_unique(num_observations, num_trees); @@ -27,7 +28,7 @@ ForestTracker::ForestTracker(Eigen::MatrixXd& covariates, std::vector& feature_types, int32_t tree_num) { +void ForestTracker::ResetRoot(MatrixMap& covariates, std::vector& feature_types, int32_t tree_num) { AssignAllSamplesToRoot(tree_num); unsorted_node_sample_tracker_->ResetTreeToRoot(tree_num, covariates.rows()); sorted_node_sample_tracker_.reset(new SortedNodeSampleTracker(presort_container_.get(), covariates, feature_types)); @@ -80,7 +81,7 @@ void ForestTracker::AssignAllSamplesToConstantPrediction(int32_t tree_num, doubl sample_pred_mapper_->AssignAllSamplesToConstantPrediction(tree_num, value); } -void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted) { +void ForestTracker::AddSplit(MatrixMap& covariates, TreeSplit& split, int32_t split_feature, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted) { sample_node_mapper_->AddSplit(covariates, split, split_feature, tree_id, split_node_id, left_node_id, right_node_id); unsorted_node_sample_tracker_->PartitionTreeNode(covariates, tree_id, split_node_id, left_node_id, right_node_id, split_feature, split); if (keep_sorted) { @@ -88,7 +89,7 @@ void ForestTracker::AddSplit(Eigen::MatrixXd& covariates, TreeSplit& split, int3 } } -void ForestTracker::RemoveSplit(Eigen::MatrixXd& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted) { +void ForestTracker::RemoveSplit(MatrixMap& covariates, Tree* tree, int32_t tree_id, int32_t split_node_id, int32_t left_node_id, int32_t right_node_id, bool keep_sorted) { unsorted_node_sample_tracker_->PruneTreeNodeToLeaf(tree_id, split_node_id); unsorted_node_sample_tracker_->UpdateObservationMapping(tree, tree_id, sample_node_mapper_.get()); // TODO: WARN if this is called from the GFR Tree Sampler @@ -138,7 +139,7 @@ int FeatureUnsortedPartition::RightNode(int node_id) { return right_nodes_[node_id]; } -void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split) { +void FeatureUnsortedPartition::PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, TreeSplit& split) { // Partition-related values data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; @@ -157,7 +158,7 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no ExpandNodeTrackingVectors(node_id, left_node_id, right_node_id, node_start_idx, num_true, num_false); } -void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value) { +void FeatureUnsortedPartition::PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, double split_value) { // Partition-related values data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; @@ -176,7 +177,7 @@ void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int no ExpandNodeTrackingVectors(node_id, left_node_id, right_node_id, node_start_idx, num_true, num_false); } -void FeatureUnsortedPartition::PartitionNode(Eigen::MatrixXd& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list) { +void FeatureUnsortedPartition::PartitionNode(MatrixMap& covariates, int node_id, int left_node_id, int right_node_id, int feature_split, std::vector const& category_list) { // Partition-related values data_size_t node_start_idx = node_begin_[node_id]; data_size_t num_node_elements = node_length_[node_id]; @@ -309,7 +310,7 @@ void FeaturePresortPartition::AddLeftRightNodes(data_size_t left_node_begin, dat node_offset_sizes_.emplace_back(right_node_begin, right_node_size); } -void FeaturePresortPartition::SplitFeature(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, TreeSplit& split) { +void FeaturePresortPartition::SplitFeature(MatrixMap& covariates, int32_t node_id, int32_t feature_index, TreeSplit& split) { // Partition-related values data_size_t node_start_idx = NodeBegin(node_id); data_size_t node_end_idx = NodeEnd(node_id); @@ -327,7 +328,7 @@ void FeaturePresortPartition::SplitFeature(Eigen::MatrixXd& covariates, int32_t AddLeftRightNodes(node_start_idx, num_true, node_start_idx + num_true, num_false); } -void FeaturePresortPartition::SplitFeatureNumeric(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, double split_value) { +void FeaturePresortPartition::SplitFeatureNumeric(MatrixMap& covariates, int32_t node_id, int32_t feature_index, double split_value) { // Partition-related values data_size_t node_start_idx = NodeBegin(node_id); data_size_t node_end_idx = NodeEnd(node_id); @@ -345,7 +346,7 @@ void FeaturePresortPartition::SplitFeatureNumeric(Eigen::MatrixXd& covariates, i AddLeftRightNodes(node_start_idx, num_true, node_start_idx + num_true, num_false); } -void FeaturePresortPartition::SplitFeatureCategorical(Eigen::MatrixXd& covariates, int32_t node_id, int32_t feature_index, std::vector const& category_list) { +void FeaturePresortPartition::SplitFeatureCategorical(MatrixMap& covariates, int32_t node_id, int32_t feature_index, std::vector const& category_list) { // Partition-related values data_size_t node_start_idx = NodeBegin(node_id); data_size_t node_end_idx = NodeEnd(node_id); diff --git a/src/tree.cpp b/src/tree.cpp index 0dd73b43..ae6eb667 100644 --- a/src/tree.cpp +++ b/src/tree.cpp @@ -5,6 +5,7 @@ * Copyright 2017-2021 by [treelite] Contributors */ #include +#include #include #include @@ -64,7 +65,7 @@ std::vector Tree::PredictFromNodes(std::vector node_indice return result; } -double Tree::PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int row_idx) { +double Tree::PredictFromNode(std::int32_t node_id, MatrixMap& basis, int row_idx) { if (!this->IsLeaf(node_id)) { Log::Fatal("Node %d is not a leaf node", node_id); } @@ -75,7 +76,7 @@ double Tree::PredictFromNode(std::int32_t node_id, Eigen::MatrixXd& basis, int r return pred; } -std::vector Tree::PredictFromNodes(std::vector node_indices, Eigen::MatrixXd& basis) { +std::vector Tree::PredictFromNodes(std::vector node_indices, MatrixMap& basis) { data_size_t n = node_indices.size(); std::vector result(n); for (data_size_t i = 0; i < n; i++) { diff --git a/test/testutils.cpp b/test/testutils.cpp index 13fc0503..f49353b2 100644 --- a/test/testutils.cpp +++ b/test/testutils.cpp @@ -26,26 +26,24 @@ TestDataset LoadSmallDatasetUnivariateBasis() { output.outcome.resize(output.n); // Covariates - output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; + output.covariates = std::vector { + 0.76696985, 0.63497100, 0.22959875, 0.74108478, 0.61817781, 0.85865784, 0.71922484, 0.74742233, 0.31588370, + 0.41955411, 0.83894646, 0.15237997, 0.12461481, 0.53356288, 0.88876378, 0.92716760, 0.05469610, 0.87172033, + 0.39253551, 0.55864950, 0.63649772, 0.38007860, 0.81407372, 0.58940162, 0.51174404, 0.51152940, 0.42850897, + 0.98791964, 0.83610831, 0.19908607, 0.67477889, 0.64578913, 0.36433653, 0.99952195, 0.88277082, 0.67865624, + 0.26033638, 0.40180207, 0.01017857, 0.48739217, 0.27398269, 0.21604451, 0.45160373, 0.19142269, 0.12730742, + 0.28658962, 0.13715010, 0.29145664, 0.19553860, 0.35568569 + }; // Leaf regression basis - output.omega << 0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, 0.86917047, 0.58444831; + output.omega = std::vector {0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, 0.86917047, 0.58444831}; // Outcome output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; // Random effects regression basis (i.e. constant, intercept-only RFX model) - output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; + output.rfx_basis = std::vector {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; // Random effects group labels for (int i = 0; i < output.n/2; i++) { @@ -74,35 +72,28 @@ TestDataset LoadSmallDatasetMultivariateBasis() { output.outcome.resize(output.n); // Covariates - output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569; + output.covariates = std::vector { + 0.76696985, 0.63497100, 0.22959875, 0.74108478, 0.61817781, 0.85865784, 0.71922484, 0.74742233, 0.31588370, + 0.41955411, 0.83894646, 0.15237997, 0.12461481, 0.53356288, 0.88876378, 0.92716760, 0.05469610, 0.87172033, + 0.39253551, 0.55864950, 0.63649772, 0.38007860, 0.81407372, 0.58940162, 0.51174404, 0.51152940, 0.42850897, + 0.98791964, 0.83610831, 0.19908607, 0.67477889, 0.64578913, 0.36433653, 0.99952195, 0.88277082, 0.67865624, + 0.26033638, 0.40180207, 0.01017857, 0.48739217, 0.27398269, 0.21604451, 0.45160373, 0.19142269, 0.12730742, + 0.28658962, 0.13715010, 0.29145664, 0.19553860, 0.35568569 + }; // Leaf regression basis - output.omega << 0.97801674, 0.3707159, - 0.34045661, 0.1312134, - 0.20528387, 0.5614470, - 0.76230322, 0.2276504, - 0.63244655, 0.9029984, - 0.61225851, 0.7448547, - 0.40492125, 0.2549813, - 0.33112223, 0.5295535, - 0.86917047, 0.5584614, - 0.58444831, 0.2365117; + output.omega = std::vector { + 0.9780167, 0.3404566, 0.2052839, 0.7623032, 0.6324466, 0.6122585, 0.4049213, 0.3311222, 0.8691705, + 0.5844483, 0.3707159, 0.1312134, 0.5614470, 0.2276504, 0.9029984, 0.7448547, 0.2549813, 0.5295535, + 0.5584614, 0.2365117 + }; // Outcome output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, 0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414; // Random effects regression basis (i.e. constant, intercept-only RFX model) - output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; + output.rfx_basis = std::vector {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; // Random effects group labels for (int i = 0; i < output.n/2; i++) { @@ -131,121 +122,97 @@ TestDataset LoadMediumDatasetUnivariateBasis() { output.outcome.resize(output.n); // Covariates - output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269, - 0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451, - 0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373, - 0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269, - 0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742, - 0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962, - 0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501, - 0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664, - 0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386, - 0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569, - 0.012786428, 0.46925501, 0.25363201, 0.3429851863, 0.2071495, - 0.887479904, 0.66166194, 0.31100105, 0.2895678403, 0.00117005, - 0.147758652, 0.14108789, 0.0361254, 0.4790630946, 0.47336526, - 0.899947367, 0.03730855, 0.33408769, 0.368503517, 0.30600202, - 0.527616998, 0.22344076, 0.20325828, 0.9296060419, 0.34518043, - 0.947085596, 0.85906392, 0.35535464, 0.529360628, 0.8781696, - 0.716097994, 0.9149628, 0.11689428, 0.1157865208, 0.31602707, - 0.433331308, 0.53848417, 0.34146036, 0.4967994317, 0.12822296, - 0.420861259, 0.28802486, 0.62324752, 0.2045601751, 0.06909585, - 0.275279159, 0.69079999, 0.29498051, 0.0082852058, 0.45247107, - 0.909681016, 0.35067747, 0.66813255, 0.3866910117, 0.65315347, - 0.828031845, 0.74096924, 0.33982958, 0.0009472317, 0.65103292, - 0.261653444, 0.43179244, 0.89632155, 0.8636559783, 0.93461464, - 0.209384357, 0.12561389, 0.69809409, 0.4752417156, 0.34963379, - 0.737655852, 0.42078584, 0.09970929, 0.5218528947, 0.36737846, - 0.975034732, 0.69977514, 0.33918481, 0.5443784453, 0.35411297, - 0.053533786, 0.98021485, 0.71035393, 0.189234901, 0.73372176, - 0.364139644, 0.47595789, 0.24620073, 0.4284725219, 0.46145259, - 0.696115067, 0.18095114, 0.66919045, 0.9517078404, 0.31686943, - 0.920878008, 0.89758374, 0.21445324, 0.5666448742, 0.29554824, - 0.397853079, 0.12019741, 0.10775046, 0.0799620333, 0.20065807, - 0.322087545, 0.68342919, 0.29873607, 0.0044371644, 0.66733723, - 0.661407114, 0.0558764, 0.10688295, 0.067841246, 0.52254161, - 0.593253554, 0.40498486, 0.97342655, 0.1917967587, 0.2078643, - 0.392762915, 0.91608107, 0.98894976, 0.3599016496, 0.70576753, - 0.758995247, 0.19899099, 0.95978035, 0.8000916124, 0.8356055, - 0.105617762, 0.12135206, 0.47523114, 0.3594282658, 0.71053726, - 0.754330984, 0.803395, 0.11297253, 0.5072350584, 0.05109695, - 0.410083859, 0.13842349, 0.3671543, 0.262290115, 0.76582706, - 0.498883172, 0.52094766, 0.23674406, 0.8919167451, 0.26313017, - 0.315790046, 0.57934811, 0.96794023, 0.7292640421, 0.63874656, - 0.969918807, 0.86839672, 0.17867962, 0.797609952, 0.3123159, - 0.291589217, 0.37982099, 0.92081884, 0.3760313739, 0.30599535, - 0.874146047, 0.64472863, 0.74944373, 0.0179410274, 0.06637048, - 0.006168369, 0.36819005, 0.48640614, 0.5182905369, 0.37514676, - 0.018794786, 0.50404546, 0.30706335, 0.239409535, 0.78368968, - 0.218041312, 0.08232156, 0.910968, 0.236348928, 0.08734924, - 0.240712896, 0.81851635, 0.75910757, 0.7666831033, 0.51030368, - 0.32422135, 0.37234399, 0.4268269, 0.0688136201, 0.52522145, - 0.737050103, 0.55333162, 0.35681609, 0.5527229193, 0.45528166, - 0.666105454, 0.44928217, 0.93068357, 0.2682658806, 0.47992145, - 0.072705164, 0.24379538, 0.36250275, 0.2693803106, 0.88583253, - 0.393483048, 0.7180344, 0.88936403, 0.9690254654, 0.41720031, - 0.726532397, 0.15675097, 0.14675637, 0.973136256, 0.86701643, - 0.206543021, 0.70612692, 0.9923119, 0.1270776591, 0.43317344, - 0.392393596, 0.6581254, 0.51121301, 0.8005079071, 0.16056554, - 0.326374607, 0.48817642, 0.68630408, 0.9265561129, 0.48683193, - 0.761818521, 0.71751337, 0.83854992, 0.134206275, 0.25700676, - 0.930924999, 0.37469277, 0.42861545, 0.7379696709, 0.9670993, - 0.601101112, 0.56631699, 0.85690728, 0.0792362478, 0.23640603, - 0.294070227, 0.02818223, 0.83060893, 0.8203584203, 0.17647972, - 0.393978659, 0.88639966, 0.80788018, 0.4202279691, 0.75344798, - 0.381183787, 0.98751161, 0.13933232, 0.5427466533, 0.15809025, - 0.203872876, 0.31032719, 0.53000948, 0.6001499062, 0.43581315, - 0.355075927, 0.10865708, 0.21823445, 0.5707600345, 0.84459087, - 0.415892882, 0.09056941, 0.85957968, 0.9296874236, 0.39317951, - 0.885163931, 0.60617414, 0.22888755, 0.9225545505, 0.41601782, - 0.803631177, 0.63855664, 0.4968153, 0.4970232591, 0.28230652, - 0.755692566, 0.36382158, 0.31492054, 0.9853899847, 0.45864754, - 0.761099141, 0.88094342, 0.82542666, 0.977985516, 0.5416208, - 0.536037115, 0.19298885, 0.67674639, 0.213044832, 0.29409245, - 0.050087478, 0.56597845, 0.22309031, 0.7668617836, 0.02385271, - 0.847882026, 0.86580035, 0.8381724, 0.618777399, 0.4707389, - 0.280194086, 0.95490103, 0.27399251, 0.5894525715, 0.17181438, - 0.261382768, 0.96124295, 0.33737123, 0.3545607659, 0.36367031, - 0.465759262, 0.17167592, 0.87114988, 0.4175856721, 0.16020522, - 0.982323635, 0.30892377, 0.96513595, 0.376671114, 0.9411435, - 0.851789546, 0.42260807, 0.37396782, 0.0759502219, 0.41219659, - 0.23932738, 0.70124641, 0.08544481, 0.8599137105, 0.35298377, - 0.985171556, 0.48493665, 0.92919919, 0.3128095574, 0.84388465, - 0.936608667, 0.70159722, 0.23570122, 0.5124408882, 0.99478731, - 0.328337863, 0.83252833, 0.29078719, 0.7531193637, 0.49378383, - 0.504403078, 0.72845174, 0.12801659, 0.5383322216, 0.12559066, - 0.906952623, 0.36801267, 0.13168735, 0.9791060984, 0.14008791, - 0.454210506, 0.67248289, 0.4041049, 0.234963659, 0.92138674, - 0.499037576, 0.7534805, 0.4168877, 0.6275620307, 0.24189188, - 0.707788941, 0.91990553, 0.56701198, 0.1408275496, 0.80566006, - 0.694437274, 0.69339343, 0.42296251, 0.8271595608, 0.53699966, - 0.447118821, 0.97512181, 0.16431204, 0.3697280197, 0.38753206, - 0.885936489, 0.94468978, 0.48918779, 0.3676202064, 0.06938232, - 0.593980148, 0.28140352, 0.27760537, 0.2819242389, 0.8730862, - 0.04248501, 0.45279893, 0.69760642, 0.0949480394, 0.42568701, - 0.35842742, 0.68098838, 0.82745029, 0.5315801166, 0.31104918, - 0.724621041, 0.28763999, 0.48743089, 0.8648093319, 0.93792148, - 0.961828358, 0.5548953, 0.7250596, 0.249875583, 0.90661302, - 0.251438316, 0.86021024, 0.65037498, 0.209739062, 0.07886205, - 0.699615913, 0.12223695, 0.20393331, 0.6357937951, 0.81502268, - 0.391076967, 0.25143855, 0.16091307, 0.6037441837, 0.50651534, - 0.343597198, 0.82570727, 0.62455707, 0.6284155636, 0.17288776, - 0.451352309, 0.29346835, 0.12641623, 0.1194773833, 0.88849468; + output.covariates = std::vector { + 0.7669698530, 0.6349709960, 0.2295987540, 0.7410847780, 0.6181778130, 0.8586578390, 0.7192248420, + 0.7474223280, 0.3158837000, 0.4195541050, 0.0127864280, 0.8874799040, 0.1477586520, 0.8999473670, + 0.5276169980, 0.9470855960, 0.7160979940, 0.4333313080, 0.4208612590, 0.2752791590, 0.9096810160, + 0.8280318450, 0.2616534440, 0.2093843570, 0.7376558520, 0.9750347320, 0.0535337860, 0.3641396440, + 0.6961150670, 0.9208780080, 0.3978530790, 0.3220875450, 0.6614071140, 0.5932535540, 0.3927629150, + 0.7589952470, 0.1056177620, 0.7543309840, 0.4100838590, 0.4988831720, 0.3157900460, 0.9699188070, + 0.2915892170, 0.8741460470, 0.0061683690, 0.0187947860, 0.2180413120, 0.2407128960, 0.3242213500, + 0.7370501030, 0.6661054540, 0.0727051640, 0.3934830480, 0.7265323970, 0.2065430210, 0.3923935960, + 0.3263746070, 0.7618185210, 0.9309249990, 0.6011011120, 0.2940702270, 0.3939786590, 0.3811837870, + 0.2038728760, 0.3550759270, 0.4158928820, 0.8851639310, 0.8036311770, 0.7556925660, 0.7610991410, + 0.5360371150, 0.0500874780, 0.8478820260, 0.2801940860, 0.2613827680, 0.4657592620, 0.9823236350, + 0.8517895460, 0.2393273800, 0.9851715560, 0.9366086670, 0.3283378630, 0.5044030780, 0.9069526230, + 0.4542105060, 0.4990375760, 0.7077889410, 0.6944372740, 0.4471188210, 0.8859364890, 0.5939801480, + 0.0424850100, 0.3584274200, 0.7246210410, 0.9618283580, 0.2514383160, 0.6996159130, 0.3910769670, + 0.3435971980, 0.4513523090, 0.8389464600, 0.1523799700, 0.1246148100, 0.5335628800, 0.8887637800, + 0.9271676000, 0.0546961000, 0.8717203300, 0.3925355100, 0.5586495000, 0.4692550100, 0.6616619400, + 0.1410878900, 0.0373085500, 0.2234407600, 0.8590639200, 0.9149628000, 0.5384841700, 0.2880248600, + 0.6907999900, 0.3506774700, 0.7409692400, 0.4317924400, 0.1256138900, 0.4207858400, 0.6997751400, + 0.9802148500, 0.4759578900, 0.1809511400, 0.8975837400, 0.1201974100, 0.6834291900, 0.0558764000, + 0.4049848600, 0.9160810700, 0.1989909900, 0.1213520600, 0.8033950000, 0.1384234900, 0.5209476600, + 0.5793481100, 0.8683967200, 0.3798209900, 0.6447286300, 0.3681900500, 0.5040454600, 0.0823215600, + 0.8185163500, 0.3723439900, 0.5533316200, 0.4492821700, 0.2437953800, 0.7180344000, 0.1567509700, + 0.7061269200, 0.6581254000, 0.4881764200, 0.7175133700, 0.3746927700, 0.5663169900, 0.0281822300, + 0.8863996600, 0.9875116100, 0.3103271900, 0.1086570800, 0.0905694100, 0.6061741400, 0.6385566400, + 0.3638215800, 0.8809434200, 0.1929888500, 0.5659784500, 0.8658003500, 0.9549010300, 0.9612429500, + 0.1716759200, 0.3089237700, 0.4226080700, 0.7012464100, 0.4849366500, 0.7015972200, 0.8325283300, + 0.7284517400, 0.3680126700, 0.6724828900, 0.7534805000, 0.9199055300, 0.6933934300, 0.9751218100, + 0.9446897800, 0.2814035200, 0.4527989300, 0.6809883800, 0.2876399900, 0.5548953000, 0.8602102400, + 0.1222369500, 0.2514385500, 0.8257072700, 0.2934683500, 0.6364977200, 0.3800786000, 0.8140737200, + 0.5894016200, 0.5117440400, 0.5115294000, 0.4285089700, 0.9879196400, 0.8361083100, 0.1990860700, + 0.2536320100, 0.3110010500, 0.0361254000, 0.3340876900, 0.2032582800, 0.3553546400, 0.1168942800, + 0.3414603600, 0.6232475200, 0.2949805100, 0.6681325500, 0.3398295800, 0.8963215500, 0.6980940900, + 0.0997092900, 0.3391848100, 0.7103539300, 0.2462007300, 0.6691904500, 0.2144532400, 0.1077504600, + 0.2987360700, 0.1068829500, 0.9734265500, 0.9889497600, 0.9597803500, 0.4752311400, 0.1129725300, + 0.3671543000, 0.2367440600, 0.9679402300, 0.1786796200, 0.9208188400, 0.7494437300, 0.4864061400, + 0.3070633500, 0.9109680000, 0.7591075700, 0.4268269000, 0.3568160900, 0.9306835700, 0.3625027500, + 0.8893640300, 0.1467563700, 0.9923119000, 0.5112130100, 0.6863040800, 0.8385499200, 0.4286154500, + 0.8569072800, 0.8306089300, 0.8078801800, 0.1393323200, 0.5300094800, 0.2182344500, 0.8595796800, + 0.2288875500, 0.4968153000, 0.3149205400, 0.8254266600, 0.6767463900, 0.2230903100, 0.8381724000, + 0.2739925100, 0.3373712300, 0.8711498800, 0.9651359500, 0.3739678200, 0.0854448100, 0.9291991900, + 0.2357012200, 0.2907871900, 0.1280165900, 0.1316873500, 0.4041049000, 0.4168877000, 0.5670119800, + 0.4229625100, 0.1643120400, 0.4891877900, 0.2776053700, 0.6976064200, 0.8274502900, 0.4874308900, + 0.7250596000, 0.6503749800, 0.2039333100, 0.1609130700, 0.6245570700, 0.1264162300, 0.6747788934, + 0.6457891271, 0.3643365290, 0.9995219493, 0.8827708189, 0.6786562400, 0.2603363760, 0.4018020707, + 0.0101785748, 0.4873921743, 0.3429851863, 0.2895678403, 0.4790630946, 0.3685035170, 0.9296060419, + 0.5293606280, 0.1157865208, 0.4967994317, 0.2045601751, 0.0082852058, 0.3866910117, 0.0009472317, + 0.8636559783, 0.4752417156, 0.5218528947, 0.5443784453, 0.1892349010, 0.4284725219, 0.9517078404, + 0.5666448742, 0.0799620333, 0.0044371644, 0.0678412460, 0.1917967587, 0.3599016496, 0.8000916124, + 0.3594282658, 0.5072350584, 0.2622901150, 0.8919167451, 0.7292640421, 0.7976099520, 0.3760313739, + 0.0179410274, 0.5182905369, 0.2394095350, 0.2363489280, 0.7666831033, 0.0688136201, 0.5527229193, + 0.2682658806, 0.2693803106, 0.9690254654, 0.9731362560, 0.1270776591, 0.8005079071, 0.9265561129, + 0.1342062750, 0.7379696709, 0.0792362478, 0.8203584203, 0.4202279691, 0.5427466533, 0.6001499062, + 0.5707600345, 0.9296874236, 0.9225545505, 0.4970232591, 0.9853899847, 0.9779855160, 0.2130448320, + 0.7668617836, 0.6187773990, 0.5894525715, 0.3545607659, 0.4175856721, 0.3766711140, 0.0759502219, + 0.8599137105, 0.3128095574, 0.5124408882, 0.7531193637, 0.5383322216, 0.9791060984, 0.2349636590, + 0.6275620307, 0.1408275496, 0.8271595608, 0.3697280197, 0.3676202064, 0.2819242389, 0.0949480394, + 0.5315801166, 0.8648093319, 0.2498755830, 0.2097390620, 0.6357937951, 0.6037441837, 0.6284155636, + 0.1194773833, 0.2739826900, 0.2160445100, 0.4516037300, 0.1914226900, 0.1273074200, 0.2865896200, + 0.1371501000, 0.2914566400, 0.1955386000, 0.3556856900, 0.2071495000, 0.0011700500, 0.4733652600, + 0.3060020200, 0.3451804300, 0.8781696000, 0.3160270700, 0.1282229600, 0.0690958500, 0.4524710700, + 0.6531534700, 0.6510329200, 0.9346146400, 0.3496337900, 0.3673784600, 0.3541129700, 0.7337217600, + 0.4614525900, 0.3168694300, 0.2955482400, 0.2006580700, 0.6673372300, 0.5225416100, 0.2078643000, + 0.7057675300, 0.8356055000, 0.7105372600, 0.0510969500, 0.7658270600, 0.2631301700, 0.6387465600, + 0.3123159000, 0.3059953500, 0.0663704800, 0.3751467600, 0.7836896800, 0.0873492400, 0.5103036800, + 0.5252214500, 0.4552816600, 0.4799214500, 0.8858325300, 0.4172003100, 0.8670164300, 0.4331734400, + 0.1605655400, 0.4868319300, 0.2570067600, 0.9670993000, 0.2364060300, 0.1764797200, 0.7534479800, + 0.1580902500, 0.4358131500, 0.8445908700, 0.3931795100, 0.4160178200, 0.2823065200, 0.4586475400, + 0.5416208000, 0.2940924500, 0.0238527100, 0.4707389000, 0.1718143800, 0.3636703100, 0.1602052200, + 0.9411435000, 0.4121965900, 0.3529837700, 0.8438846500, 0.9947873100, 0.4937838300, 0.1255906600, + 0.1400879100, 0.9213867400, 0.2418918800, 0.8056600600, 0.5369996600, 0.3875320600, 0.0693823200, + 0.8730862000, 0.4256870100, 0.3110491800, 0.9379214800, 0.9066130200, 0.0788620500, 0.8150226800, + 0.5065153400, 0.1728877600, 0.8884946800 + }; // Leaf regression basis - output.omega << 0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, - 0.86917047, 0.58444831, 0.33316433, 0.62217709, 0.96820668, 0.20778425, 0.23764591, 0.94193115, - 0.03869153, 0.60847765, 0.51535811, 0.81554404, 0.78515289, 0.23337815, 0.16730957, 0.02168331, - 0.08699654, 0.34067049, 0.93141264, 0.03679176, 0.4364772, 0.2644173, 0.23717182, 0.59084776, - 0.63438143, 0.57132227, 0.17568721, 0.15552373, 0.8625478, 0.02466334, 0.47269628, 0.97782225, - 0.90593388, 0.82272111, 0.67374992, 0.47619752, 0.5276532, 0.75182919, 0.09559243, 0.5126907, - 0.45892102, 0.11357212, 0.77861167, 0.78424907, 0.84693988, 0.38814934, 0.01010333, 0.10064384, - 0.68664865, 0.1264298, 0.14314708, 0.62679815, 0.71101772, 0.43504811, 0.8868721, 0.95098048, - 0.38291537, 0.71337451, 0.12109764, 0.68943347, 0.89878588, 0.67524475, 0.95549402, 0.58758459, - 0.68558459, 0.16794963, 0.23680754, 0.40289479, 0.98291039, 0.87276966, 0.76995475, 0.55282963, - 0.12448394, 0.5479543, 0.8718802, 0.14515363, 0.71311006, 0.39196408, 0.94504373, 0.44020353, - 0.24090674, 0.52675625, 0.86674581, 0.90576332, 0.09167602, 0.74795585, 0.26901811, 0.544173, - 0.03336554, 0.8314331, 0.27185696, 0.83434459; + output.omega = std::vector { + 0.97801674, 0.34045661, 0.20528387, 0.76230322, 0.63244655, 0.61225851, 0.40492125, 0.33112223, + 0.86917047, 0.58444831, 0.33316433, 0.62217709, 0.96820668, 0.20778425, 0.23764591, 0.94193115, + 0.03869153, 0.60847765, 0.51535811, 0.81554404, 0.78515289, 0.23337815, 0.16730957, 0.02168331, + 0.08699654, 0.34067049, 0.93141264, 0.03679176, 0.4364772, 0.2644173, 0.23717182, 0.59084776, + 0.63438143, 0.57132227, 0.17568721, 0.15552373, 0.8625478, 0.02466334, 0.47269628, 0.97782225, + 0.90593388, 0.82272111, 0.67374992, 0.47619752, 0.5276532, 0.75182919, 0.09559243, 0.5126907, + 0.45892102, 0.11357212, 0.77861167, 0.78424907, 0.84693988, 0.38814934, 0.01010333, 0.10064384, + 0.68664865, 0.1264298, 0.14314708, 0.62679815, 0.71101772, 0.43504811, 0.8868721, 0.95098048, + 0.38291537, 0.71337451, 0.12109764, 0.68943347, 0.89878588, 0.67524475, 0.95549402, 0.58758459, + 0.68558459, 0.16794963, 0.23680754, 0.40289479, 0.98291039, 0.87276966, 0.76995475, 0.55282963, + 0.12448394, 0.5479543, 0.8718802, 0.14515363, 0.71311006, 0.39196408, 0.94504373, 0.44020353, + 0.24090674, 0.52675625, 0.86674581, 0.90576332, 0.09167602, 0.74795585, 0.26901811, 0.544173, + 0.03336554, 0.8314331, 0.27185696, 0.83434459 + }; // Outcome output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379, 0.347249942, 0.546179903, @@ -264,12 +231,14 @@ TestDataset LoadMediumDatasetUnivariateBasis() { -1.971537882, 0.962010578, 1.552073631, 0.459464684, -0.149159276, 0.203079262, -0.453721958, 2.152977755, 0.948865461; // Random effects regression basis (i.e. constant, intercept-only RFX model) - output.rfx_basis << 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1; + output.rfx_basis = std::vector{ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + }; // Random effects group labels for (int i = 0; i < output.n/2; i++) { diff --git a/test/testutils.h b/test/testutils.h index 8bbcda01..9406d8ec 100644 --- a/test/testutils.h +++ b/test/testutils.h @@ -6,6 +6,7 @@ #define STOCHTREE_TESTUTILS_H_ #include +#include #include #include @@ -14,16 +15,16 @@ namespace StochTree { namespace TestUtils { struct TestDataset { - Eigen::Matrix covariates; - Eigen::Matrix omega; - Eigen::Matrix rfx_basis; + std::vector covariates; + std::vector omega; + std::vector rfx_basis; Eigen::VectorXd outcome; std::vector rfx_groups; int n; int x_cols; int omega_cols; int rfx_basis_cols; - bool row_major{true}; + bool row_major{false}; }; /*! Creates a small dataset (10 observations) */ From d505bd8a4e93707c67f23daad6707cf2ce92d065 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 29 Apr 2024 04:14:58 -0500 Subject: [PATCH 4/8] Initial C++ wrapper around BCF sampler. Segfaulting. Debug in the morning --- debug/api_debug.cpp | 10 +- debug/bcf_debug.cpp | 230 +++++++------------- include/stochtree/cpp_api.h | 413 ++++++++++++++++++++++++++++++++++++ include/stochtree/meta.h | 6 + 4 files changed, 503 insertions(+), 156 deletions(-) create mode 100644 include/stochtree/cpp_api.h diff --git a/debug/api_debug.cpp b/debug/api_debug.cpp index ba335085..8a6cffc7 100644 --- a/debug/api_debug.cpp +++ b/debug/api_debug.cpp @@ -20,11 +20,11 @@ namespace StochTree{ -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; +// enum ForestLeafModel { +// kConstant, +// kUnivariateRegression, +// kMultivariateRegression +// }; void GenerateRandomData(std::vector& covariates, std::vector& basis, std::vector& outcome, std::vector& rfx_basis, std::vector& rfx_groups, int n, int x_cols, int omega_cols, int y_cols, int rfx_basis_cols) { std::mt19937 gen(101); diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 1ef35f2d..52c44436 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -1,5 +1,6 @@ /*! Copyright (c) 2024 stochtree authors*/ #include +#include #include #include #include @@ -21,17 +22,11 @@ namespace StochTree{ -enum ForestLeafModel { - kConstant, - kUnivariateRegression, - kMultivariateRegression -}; - -double calibrate_lambda(ForestDataset& covariates, ColumnVector& residual, double nu, double q) { +double calibrate_lambda(std::vector& covariates, std::vector& residual, double nu, double q, int num_rows, int x_cols) { // Linear model of residual ~ covariates - double n = static_cast(covariates.NumObservations()); - Eigen::MatrixXd X = covariates.GetCovariates(); - Eigen::VectorXd y = residual.GetData(); + double n = static_cast(residual.size()); + Eigen::Map> X(covariates.data(), num_rows, x_cols); + Eigen::Map> y(residual.data(), num_rows); Eigen::VectorXd beta = (X.transpose() * X).inverse() * (X.transpose() * y); double sum_sq_resid = (y - X * beta).transpose() * (y - X * beta); double sigma_hat = sum_sq_resid / n; @@ -162,14 +157,14 @@ void GenerateRandomData(std::vector& covariates, std::vector& pr } } -void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& outcome_scale) { - data_size_t n = residual.NumRows(); +void OutcomeOffsetScale(std::vector& residual, double& outcome_offset, double& outcome_scale) { + data_size_t n = residual.size(); double outcome_val = 0.0; double outcome_sum = 0.0; double outcome_sum_squares = 0.0; double var_y = 0.0; for (data_size_t i = 0; i < n; i++){ - outcome_val = residual.GetElement(i); + outcome_val = residual.at(i); outcome_sum += outcome_val; outcome_sum_squares += std::pow(outcome_val, 2.0); } @@ -178,46 +173,46 @@ void OutcomeOffsetScale(ColumnVector& residual, double& outcome_offset, double& outcome_offset = outcome_sum / static_cast(n); double previous_residual; for (data_size_t i = 0; i < n; i++){ - previous_residual = residual.GetElement(i); - residual.SetElement(i, (previous_residual - outcome_offset) / outcome_scale); + previous_residual = residual.at(i); + residual.at(i) = (previous_residual - outcome_offset) / outcome_scale; } } -void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, - ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, - ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { - if (leaf_model_type == ForestLeafModel::kConstant) { - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); - } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { - GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); - } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { - GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); - } -} - -void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, - ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, - ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { - if (leaf_model_type == ForestLeafModel::kConstant) { - GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { - GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { - GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); - MCMCForestSampler sampler = MCMCForestSampler(); - sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); - } -} +// void sampleGFR(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, +// ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, +// ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { +// if (leaf_model_type == ForestLeafModel::kConstant) { +// GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); +// GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); +// } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { +// GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); +// GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); +// } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { +// GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); +// GFRForestSampler sampler = GFRForestSampler(cutpoint_grid_size); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance, feature_types); +// } +// } + +// void sampleMCMC(ForestTracker& tracker, TreePrior& tree_prior, ForestContainer& forest_samples, ForestDataset& dataset, +// ColumnVector& residual, std::mt19937& rng, std::vector& feature_types, std::vector& var_weights_vector, +// ForestLeafModel leaf_model_type, Eigen::MatrixXd& leaf_scale_matrix, double global_variance, double leaf_scale, int cutpoint_grid_size) { +// if (leaf_model_type == ForestLeafModel::kConstant) { +// GaussianConstantLeafModel leaf_model = GaussianConstantLeafModel(leaf_scale); +// MCMCForestSampler sampler = MCMCForestSampler(); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } else if (leaf_model_type == ForestLeafModel::kUnivariateRegression) { +// GaussianUnivariateRegressionLeafModel leaf_model = GaussianUnivariateRegressionLeafModel(leaf_scale); +// MCMCForestSampler sampler = MCMCForestSampler(); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } else if (leaf_model_type == ForestLeafModel::kMultivariateRegression) { +// GaussianMultivariateRegressionLeafModel leaf_model = GaussianMultivariateRegressionLeafModel(leaf_scale_matrix); +// MCMCForestSampler sampler = MCMCForestSampler(); +// sampler.SampleOneIter(tracker, forest_samples, leaf_model, dataset, residual, tree_prior, rng, var_weights_vector, global_variance); +// } +// } void RunAPI() { // Data dimensions @@ -245,30 +240,22 @@ void RunAPI() { // Define internal datasets bool row_major = false; - // Construct datasets for training, include pi(x) as a covariate in the prognostic forest - ForestDataset tau_dataset = ForestDataset(); - tau_dataset.AddCovariates(covariates_raw.data(), n, x_cols, row_major); - tau_dataset.AddBasis(treatment_raw.data(), n, 1, row_major); - ForestDataset mu_dataset = ForestDataset(); - mu_dataset.AddCovariates(covariates_pi.data(), n, x_cols+1, row_major); - ColumnVector residual = ColumnVector(outcome_raw.data(), n); - // Center and scale the data double outcome_offset; double outcome_scale; - OutcomeOffsetScale(residual, outcome_offset, outcome_scale); + OutcomeOffsetScale(outcome_raw, outcome_offset, outcome_scale); // Initialize ensembles for prognostic and treatment forests int num_trees_mu = 200; int num_trees_tau = 50; ForestContainer forest_samples_mu = ForestContainer(num_trees_mu, 1, true); ForestContainer forest_samples_tau = ForestContainer(num_trees_tau, 1, false); + forest_samples_mu.InitializeRoot(0.); + forest_samples_tau.InitializeRoot(0.); // Initialize leaf models for mu and tau forests double leaf_prior_scale_mu = (outcome_scale*outcome_scale)/num_trees_mu; double leaf_prior_scale_tau = (outcome_scale*outcome_scale)/(2*num_trees_tau); - GaussianConstantLeafModel leaf_model_mu = GaussianConstantLeafModel(leaf_prior_scale_mu); - GaussianUnivariateRegressionLeafModel leaf_model_tau = GaussianUnivariateRegressionLeafModel(leaf_prior_scale_tau); // Initialize forest sampling machinery std::vector feature_types_mu(x_cols + 1, FeatureType::kNumeric); @@ -283,110 +270,51 @@ void RunAPI() { double beta_tau = 3.0; int min_samples_leaf_mu = 5; int min_samples_leaf_tau = 5; - int cutpoint_grid_size_mu = 100; - int cutpoint_grid_size_tau = 100; + int cutpoint_grid_size = 100; double a_leaf_mu = 3.; double b_leaf_mu = leaf_prior_scale_mu; double a_leaf_tau = 3.; double b_leaf_tau = leaf_prior_scale_tau; double nu = 3.; - double lamb = calibrate_lambda(tau_dataset, residual, nu, 0.9); - ForestLeafModel leaf_model_type_mu = ForestLeafModel::kConstant; - ForestLeafModel leaf_model_type_tau = ForestLeafModel::kUnivariateRegression; + double lamb = calibrate_lambda(covariates_raw, outcome_raw, nu, 0.9, n, x_cols); + double b1 = 0.5; + double b0 = -0.5; + // ForestLeafModel leaf_model_type_mu = ForestLeafModel::kConstant; + // ForestLeafModel leaf_model_type_tau = ForestLeafModel::kUnivariateRegression; + int num_gfr_samples = 10; + int num_mcmc_samples = 1000; + int num_samples = num_gfr_samples + num_mcmc_samples; - // Set leaf model parameters - double leaf_scale_mu; - double leaf_scale_tau = leaf_prior_scale_tau; - Eigen::MatrixXd leaf_scale_matrix_mu; - Eigen::MatrixXd leaf_scale_matrix_tau; + // // Set leaf model parameters + // double leaf_scale_mu; + // double leaf_scale_tau = leaf_prior_scale_tau; + // Eigen::MatrixXd leaf_scale_matrix_mu; + // Eigen::MatrixXd leaf_scale_matrix_tau; // Set global variance - double global_variance_init = 1.0; - double global_variance; - - // Set variable weights - double const_var_wt_mu = static_cast(1/(x_cols+1)); - std::vector variable_weights_mu(x_cols+1, const_var_wt_mu); - double const_var_wt_tau = static_cast(1/x_cols); - std::vector variable_weights_tau(x_cols, const_var_wt_tau); - - // Initialize tracker and tree prior - ForestTracker mu_tracker = ForestTracker(mu_dataset.GetCovariates(), feature_types_mu, num_trees_mu, n); - ForestTracker tau_tracker = ForestTracker(tau_dataset.GetCovariates(), feature_types_tau, num_trees_tau, n); - TreePrior tree_prior_mu = TreePrior(alpha_mu, beta_mu, min_samples_leaf_mu); - TreePrior tree_prior_tau = TreePrior(alpha_tau, beta_tau, min_samples_leaf_tau); + double sigma2 = 1.0; // Initialize a random number generator std::random_device rd; std::mt19937 rng = std::mt19937(rd()); - - // Initialize variance models - GlobalHomoskedasticVarianceModel global_var_model = GlobalHomoskedasticVarianceModel(); - LeafNodeHomoskedasticVarianceModel leaf_var_model_mu = LeafNodeHomoskedasticVarianceModel(); // Initialize storage for samples of variance - std::vector global_variance_samples{}; - std::vector leaf_variance_samples_mu{}; - - // Run the GFR sampler - int num_gfr_samples = 10; - for (int i = 0; i < num_gfr_samples; i++) { - if (i == 0) { - global_variance = global_variance_init; - leaf_scale_mu = leaf_prior_scale_mu; - } else { - global_variance = global_variance_samples[i-1]; - leaf_scale_mu = leaf_variance_samples_mu[i-1]; - } - - // Sample mu ensemble - sampleGFR(mu_tracker, tree_prior_mu, forest_samples_mu, mu_dataset, residual, rng, feature_types_mu, variable_weights_mu, - leaf_model_type_mu, leaf_scale_matrix_mu, global_variance, leaf_scale_mu, cutpoint_grid_size_mu); - - // Sample leaf node variance - leaf_variance_samples_mu.push_back(leaf_var_model_mu.SampleVarianceParameter(forest_samples_mu.GetEnsemble(i), a_leaf_mu, b_leaf_mu, rng)); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); - - // Sample tau ensemble - sampleGFR(tau_tracker, tree_prior_tau, forest_samples_tau, tau_dataset, residual, rng, feature_types_tau, variable_weights_tau, - leaf_model_type_tau, leaf_scale_matrix_tau, global_variance, leaf_scale_tau, cutpoint_grid_size_tau); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); - } - - // Run the MCMC sampler - int num_mcmc_samples = 10000; - for (int i = num_gfr_samples; i < num_gfr_samples + num_mcmc_samples; i++) { - if (i == 0) { - global_variance = global_variance_init; - leaf_scale_mu = leaf_prior_scale_mu; - } else { - global_variance = global_variance_samples[i-1]; - leaf_scale_mu = leaf_variance_samples_mu[i-1]; - } - - // Sample mu ensemble - sampleMCMC(mu_tracker, tree_prior_mu, forest_samples_mu, mu_dataset, residual, rng, feature_types_mu, variable_weights_mu, - leaf_model_type_mu, leaf_scale_matrix_mu, global_variance, leaf_scale_mu, cutpoint_grid_size_mu); - - // Sample leaf node variance - leaf_variance_samples_mu.push_back(leaf_var_model_mu.SampleVarianceParameter(forest_samples_mu.GetEnsemble(i), a_leaf_mu, b_leaf_mu, rng)); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); - - // Sample tau ensemble - sampleMCMC(tau_tracker, tree_prior_tau, forest_samples_tau, tau_dataset, residual, rng, feature_types_tau, variable_weights_tau, - leaf_model_type_tau, leaf_scale_matrix_tau, global_variance, leaf_scale_tau, cutpoint_grid_size_tau); - - // Sample global variance - global_variance_samples.push_back(global_var_model.SampleVarianceParameter(residual.GetData(), nu, nu*lamb, rng)); - - // Estimatw - } + std::vector global_variance_samples(num_samples); + std::vector leaf_variance_samples_mu(num_samples); + + // Initialize the BCF sampler + BCFModel bcf = BCFModel(); + bcf.LoadTrain(outcome_raw.data(), n, covariates_pi.data(), x_cols+1, + covariates_raw.data(), x_cols, treatment_raw.data(), 1, true); + bcf.ResetGlobalVarSamples(global_variance_samples.data(), num_samples); + bcf.ResetPrognosticLeafVarSamples(leaf_variance_samples_mu.data(), num_samples); + + // Run the BCF sampler + bcf.SampleBCF(&forest_samples_mu, &forest_samples_tau, &rng, cutpoint_grid_size, + leaf_prior_scale_mu, leaf_prior_scale_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, + min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, + sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu, feature_types_tau, + num_gfr_samples, 0, num_mcmc_samples, 0.0, 0.0); } } // namespace StochTree diff --git a/include/stochtree/cpp_api.h b/include/stochtree/cpp_api.h new file mode 100644 index 00000000..71debc55 --- /dev/null +++ b/include/stochtree/cpp_api.h @@ -0,0 +1,413 @@ +/*! + * Copyright (c) 2024 stochtree authors. + * + * High-level C++ API for BART and BCF + */ +#ifndef STOCHTREE_CPP_API_H_ +#define STOCHTREE_CPP_API_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace StochTree { + +struct BCFParameters { + int cutpoint_grid_size; + double sigma_leaf_mu; + double sigma_leaf_tau; + Eigen::MatrixXd sigma_leaf_tau_mat; + double alpha_mu; + double alpha_tau; + double beta_mu; + double beta_tau; + int min_samples_leaf_mu; + int min_samples_leaf_tau; + double nu; + double lamb; + double a_leaf_mu; + double a_leaf_tau; + double b_leaf_mu; + double b_leaf_tau; + double sigma2; + int num_trees_mu; + int num_trees_tau; + double b1; + double b0; + int num_gfr; + int num_burnin; + int num_mcmc; + std::vector feature_types_mu; + std::vector feature_types_tau; + double leaf_init_mu; + double leaf_init_tau; + + BCFParameters(int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + cutpoint_grid_size = cutpoint_grid_size; + sigma_leaf_mu = sigma_leaf_mu; + sigma_leaf_tau = sigma_leaf_tau; + alpha_mu = alpha_mu; + alpha_tau = alpha_tau; + beta_mu = beta_mu; + beta_tau = beta_tau; + min_samples_leaf_mu = min_samples_leaf_mu; + min_samples_leaf_tau = min_samples_leaf_tau; + nu = nu; + lamb = lamb; + a_leaf_mu = a_leaf_mu; + a_leaf_tau = a_leaf_tau; + b_leaf_mu = b_leaf_mu; + b_leaf_tau = b_leaf_tau; + sigma2 = sigma2; + num_trees_mu = num_trees_mu; + num_trees_tau = num_trees_tau; + b1 = b1; + b0 = b0; + num_gfr = num_gfr; + num_burnin = num_burnin; + num_mcmc = num_mcmc; + feature_types_mu = feature_types_mu; + feature_types_tau = feature_types_tau; + leaf_init_mu = leaf_init_mu; + leaf_init_tau = leaf_init_tau; + } + BCFParameters(int cutpoint_grid_size, double sigma_leaf_mu, Eigen::MatrixXd& sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + cutpoint_grid_size = cutpoint_grid_size; + sigma_leaf_mu = sigma_leaf_mu; + sigma_leaf_tau_mat = sigma_leaf_tau; + alpha_mu = alpha_mu; + alpha_tau = alpha_tau; + beta_mu = beta_mu; + beta_tau = beta_tau; + min_samples_leaf_mu = min_samples_leaf_mu; + min_samples_leaf_tau = min_samples_leaf_tau; + nu = nu; + lamb = lamb; + a_leaf_mu = a_leaf_mu; + a_leaf_tau = a_leaf_tau; + b_leaf_mu = b_leaf_mu; + b_leaf_tau = b_leaf_tau; + sigma2 = sigma2; + num_trees_mu = num_trees_mu; + num_trees_tau = num_trees_tau; + b1 = b1; + b0 = b0; + num_gfr = num_gfr; + num_burnin = num_burnin; + num_mcmc = num_mcmc; + feature_types_mu = feature_types_mu; + feature_types_tau = feature_types_tau; + leaf_init_mu = leaf_init_mu; + leaf_init_tau = leaf_init_tau; + } +}; + +/*! \brief Class that coordinates BCF sampler and returns results */ +template +class BCFModel { + public: + BCFModel(){} + ~BCFModel(){} + void SampleBCF(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, + int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + BCFParameters params(cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, + min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, + sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu, feature_types_tau, + num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau); + SampleBCFInternal(forest_samples_mu, forest_samples_tau, rng, params); + } + + void SampleBCF(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, + int cutpoint_grid_size, double sigma_leaf_mu, Eigen::MatrixXd& sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + BCFParameters params(cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, alpha_mu, alpha_tau, beta_mu, beta_tau, + min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, + sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu, feature_types_tau, + num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau); + SampleBCFInternal(forest_samples_mu, forest_samples_tau, rng, params); + } + + void LoadTrain(double* residual_data_ptr, int num_rows, double* prognostic_covariate_data_ptr, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, + int num_treatment, bool treatment_binary) { + // Residual + residual_train_.LoadData(residual_data_ptr, num_rows); + + // Prognostic term training dataset + forest_dataset_mu_train_.AddCovariates(prognostic_covariate_data_ptr, num_rows, num_prognostic_covariates, false); + + // Treatment term training dataset + forest_dataset_tau_train_.AddCovariates(treatment_covariate_data_ptr, num_rows, num_treatment_covariates, false); + forest_dataset_tau_train_.AddBasis(treatment_data_ptr, num_rows, num_treatment, false); + treatment_dim_ = num_treatment; + treatment_binary_ = treatment_binary; + } + + void LoadTrain(double* residual_data_ptr, int num_rows, double* prognostic_covariate_data_ptr, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, + int num_treatment, bool treatment_binary, double* weights_data_ptr) { + // Residual + residual_train_.LoadData(residual_data_ptr, num_rows); + + // Prognostic term training dataset + forest_dataset_mu_train_.AddCovariates(prognostic_covariate_data_ptr, num_rows, num_prognostic_covariates, false); + forest_dataset_mu_train_.AddVarianceWeights(weights_data_ptr, num_rows); + + // Treatment term training dataset + forest_dataset_tau_train_.AddCovariates(treatment_covariate_data_ptr, num_rows, num_treatment_covariates, false); + forest_dataset_tau_train_.AddBasis(treatment_data_ptr, num_rows, num_treatment, false); + forest_dataset_tau_train_.AddVarianceWeights(weights_data_ptr, num_rows); + treatment_dim_ = num_treatment; + treatment_binary_ = treatment_binary; + + has_weights_ = true; + } + + void LoadTest(double* prognostic_covariate_data_ptr, int num_rows, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, int num_treatment) { + // Prognostic term training dataset + forest_dataset_mu_test_.AddCovariates(prognostic_covariate_data_ptr, num_rows, num_prognostic_covariates, false); + + // Treatment term training dataset + forest_dataset_tau_test_.AddCovariates(treatment_covariate_data_ptr, num_rows, num_treatment_covariates, false); + forest_dataset_tau_test_.AddBasis(treatment_data_ptr, num_rows, num_treatment, false); + + has_test_ = true; + } + + void ResetGlobalVarSamples(double* data_ptr, int num_samples) { + new (&global_var_samples_) MatrixMap(data_ptr, num_samples, 1); + global_var_random_ = true; + } + + void ResetPrognosticLeafVarSamples(double* data_ptr, int num_samples) { + new (&prognostic_leaf_var_samples_) MatrixMap(data_ptr, num_samples, 1); + prognostic_leaf_var_random_ = true; + } + + void ResetTreatmentLeafVarSamples(double* data_ptr, int num_samples) { + new (&treatment_leaf_var_samples_) MatrixMap(data_ptr, num_samples, 1); + treatment_leaf_var_random_ = true; + } + + void ResetTreatedCodingSamples(double* data_ptr, int num_samples) { + new (&b1_samples_) MatrixMap(data_ptr, num_samples, 1); + } + + void ResetControlCodingSamples(double* data_ptr, int num_samples) { + new (&b0_samples_) MatrixMap(data_ptr, num_samples, 1); + } + + void ResetTrainPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples) { + new (&muhat_train_samples_) MatrixMap(muhat_data_ptr, num_obs, num_samples); + new (&tauhat_train_samples_) MatrixMap(tauhat_data_ptr, num_obs, num_samples); + new (&yhat_train_samples_) MatrixMap(yhat_data_ptr, num_obs, num_samples); + } + + void ResetTestPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples) { + new (&muhat_test_samples_) MatrixMap(muhat_data_ptr, num_obs, num_samples); + new (&tauhat_test_samples_) MatrixMap(tauhat_data_ptr, num_obs, num_samples); + new (&yhat_test_samples_) MatrixMap(yhat_data_ptr, num_obs, num_samples); + } + + private: + // Details of the model + int treatment_dim_{1}; + bool adaptive_coding_{false}; + bool treatment_binary_{true}; + bool global_var_random_{false}; + bool prognostic_leaf_var_random_{false}; + bool treatment_leaf_var_random_{false}; + bool has_weights_{false}; + bool has_test_{false}; + + // Train and test sets + ColumnVector residual_train_; + ForestDataset forest_dataset_mu_train_; + ForestDataset forest_dataset_mu_test_; + ForestDataset forest_dataset_tau_train_; + ForestDataset forest_dataset_tau_test_; + + // There is no default initializer for Eigen::Map, so we initialize to + // NULL, 1, 1 and reset the map when necessary + MatrixMap global_var_samples_{NULL,1,1}; + MatrixMap prognostic_leaf_var_samples_{NULL,1,1}; + MatrixMap treatment_leaf_var_samples_{NULL,1,1}; + MatrixMap b1_samples_{NULL,1,1}; + MatrixMap b0_samples_{NULL,1,1}; + MatrixMap muhat_train_samples_{NULL,1,1}; + MatrixMap tauhat_train_samples_{NULL,1,1}; + MatrixMap yhat_train_samples_{NULL,1,1}; + MatrixMap muhat_test_samples_{NULL,1,1}; + MatrixMap tauhat_test_samples_{NULL,1,1}; + MatrixMap yhat_test_samples_{NULL,1,1}; + + TauModelType InitializeTauLeafModel(double sigma_leaf, Eigen::MatrixXd& sigma_leaf_mat) { + if constexpr (std::is_same_v) { + return TauModelType(sigma_leaf_mat); + } else { + return TauModelType(sigma_leaf); + } + } + + void SampleBCFInternal(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, BCFParameters& params) { + // Initialize leaf models for mu and tau forests + GaussianConstantLeafModel leaf_model_mu = GaussianConstantLeafModel(params.sigma_leaf_mu); + TauModelType leaf_model_tau = InitializeTauLeafModel(params.sigma_leaf_tau, params.sigma_leaf_tau_mat); + // TauModelType leaf_model_tau; + // if constexpr (std::is_same_v) { + // leaf_model_tau = TauModelType(params.sigma_leaf_tau_mat); + // } else { + // leaf_model_tau = TauModelType(params.sigma_leaf_tau); + // } + + // Set variable weights + double const_var_wt_mu = static_cast(1/(forest_dataset_mu_train_.NumCovariates())); + std::vector variable_weights_mu(forest_dataset_mu_train_.NumCovariates(), const_var_wt_mu); + double const_var_wt_tau = static_cast(1/forest_dataset_tau_train_.NumCovariates()); + std::vector variable_weights_tau(forest_dataset_tau_train_.NumCovariates(), const_var_wt_tau); + + // Initialize trackers and tree priors + int n = forest_dataset_mu_train_.NumObservations(); + ForestTracker mu_tracker = ForestTracker(forest_dataset_mu_train_.GetCovariates(), params.feature_types_mu, params.num_trees_mu, n); + ForestTracker tau_tracker = ForestTracker(forest_dataset_tau_train_.GetCovariates(), params.feature_types_tau, params.num_trees_tau, n); + TreePrior tree_prior_mu = TreePrior(params.alpha_mu, params.beta_mu, params.min_samples_leaf_mu); + TreePrior tree_prior_tau = TreePrior(params.alpha_tau, params.beta_tau, params.min_samples_leaf_tau); + + // Initialize leaf values + // TODO: handle multivariate tau case + forest_samples_mu->SetLeafValue(0, params.leaf_init_mu); + forest_samples_tau->SetLeafValue(0, params.leaf_init_tau); + UpdateResidualEntireForest(mu_tracker, forest_dataset_mu_train_, residual_train_, forest_samples_mu->GetEnsemble(0), false, std::minus()); + UpdateResidualEntireForest(tau_tracker, forest_dataset_tau_train_, residual_train_, forest_samples_tau->GetEnsemble(0), true, std::minus()); + + // Variance models (if requested) + GlobalHomoskedasticVarianceModel global_var_model; + LeafNodeHomoskedasticVarianceModel prognostic_leaf_var_model; + LeafNodeHomoskedasticVarianceModel treatment_leaf_var_model; + if (global_var_random_) global_var_model = GlobalHomoskedasticVarianceModel(); + if (prognostic_leaf_var_random_) prognostic_leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + if (treatment_leaf_var_random_) treatment_leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + + // Initial values of (potentially random) parameters + double sigma2 = params.sigma2; + double leaf_scale_mu = params.sigma_leaf_mu; + double leaf_scale_tau = params.sigma_leaf_tau; + Eigen::MatrixXd leaf_scale_tau_mat = params.sigma_leaf_tau_mat; + + if (params.num_gfr > 0) { + // Initialize GFR sampler for mu and tau + GFRForestSampler mu_sampler_gfr = GFRForestSampler(params.cutpoint_grid_size); + GFRForestSampler tau_sampler_gfr = GFRForestSampler(params.cutpoint_grid_size); + + // Run the GFR sampler + for (int i = 0; i < params.num_gfr; i++) { + + // Sample mu ensemble + mu_sampler_gfr.SampleOneIter(mu_tracker, *forest_samples_mu, leaf_model_mu, forest_dataset_mu_train_, residual_train_, tree_prior_mu, *rng, variable_weights_mu, sigma2, params.feature_types_mu, true); + + // Sample leaf node variance + if (prognostic_leaf_var_random_) { + leaf_scale_mu = prognostic_leaf_var_model.SampleVarianceParameter(forest_samples_mu->GetEnsemble(i), params.a_leaf_mu, params.b_leaf_mu, *rng); + prognostic_leaf_var_samples_(i) = leaf_scale_mu; + } + + // Sample global variance + if (global_var_random_) { + sigma2 = global_var_model.SampleVarianceParameter(residual_train_.GetData(), params.nu, params.nu*params.lamb, *rng); + global_var_samples_(i) = sigma2; + } + + // Sample tau ensemble + tau_sampler_gfr.SampleOneIter(tau_tracker, *forest_samples_tau, leaf_model_tau, forest_dataset_tau_train_, residual_train_, tree_prior_tau, *rng, variable_weights_tau, sigma2, params.feature_types_tau, true); + + // Sample leaf node variance + if (treatment_leaf_var_random_) { + leaf_scale_tau = treatment_leaf_var_model.SampleVarianceParameter(forest_samples_tau->GetEnsemble(i), params.a_leaf_tau, params.b_leaf_tau, *rng); + treatment_leaf_var_samples_(i) = leaf_scale_tau; + } + + // Sample global variance + if (global_var_random_) { + sigma2 = global_var_model.SampleVarianceParameter(residual_train_.GetData(), params.nu, params.nu*params.lamb, *rng); + global_var_samples_(i) = sigma2; + } + } + } + + if (params.num_burnin + params.num_mcmc > 0) { + // Initialize GFR sampler for mu and tau + MCMCForestSampler mu_sampler_mcmc = MCMCForestSampler(); + MCMCForestSampler tau_sampler_mcmc = MCMCForestSampler(); + + // Run the GFR sampler + for (int i = params.num_gfr; i < params.num_gfr + params.num_burnin + params.num_mcmc; i++) { + + // Sample mu ensemble + mu_sampler_mcmc.SampleOneIter(mu_tracker, *forest_samples_mu, leaf_model_mu, forest_dataset_mu_train_, residual_train_, tree_prior_mu, *rng, variable_weights_mu, sigma2, true); + + // Sample leaf node variance + if (prognostic_leaf_var_random_) { + leaf_scale_mu = prognostic_leaf_var_model.SampleVarianceParameter(forest_samples_mu->GetEnsemble(i), params.a_leaf_mu, params.b_leaf_mu, *rng); + prognostic_leaf_var_samples_(i) = leaf_scale_mu; + } + + // Sample global variance + if (global_var_random_) { + sigma2 = global_var_model.SampleVarianceParameter(residual_train_.GetData(), params.nu, params.nu*params.lamb, *rng); + global_var_samples_(i) = sigma2; + } + + // Sample tau ensemble + tau_sampler_mcmc.SampleOneIter(tau_tracker, *forest_samples_tau, leaf_model_tau, forest_dataset_tau_train_, residual_train_, tree_prior_tau, *rng, variable_weights_tau, sigma2, true); + + // Sample leaf node variance + if (treatment_leaf_var_random_) { + leaf_scale_tau = treatment_leaf_var_model.SampleVarianceParameter(forest_samples_tau->GetEnsemble(i), params.a_leaf_tau, params.b_leaf_tau, *rng); + treatment_leaf_var_samples_(i) = leaf_scale_tau; + } + + // Sample global variance + if (global_var_random_) { + sigma2 = global_var_model.SampleVarianceParameter(residual_train_.GetData(), params.nu, params.nu*params.lamb, *rng); + global_var_samples_(i) = sigma2; + } + } + } + } +}; + +} // namespace StochTree + +#endif // STOCHTREE_CPP_API_H_ diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index 5ae002e2..10b96f9f 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -31,6 +31,12 @@ namespace StochTree { +enum ForestLeafModel { + kConstant, + kUnivariateRegression, + kMultivariateRegression +}; + enum FeatureType { kNumeric, kOrderedCategorical, From 16a48cea43b143d0df3e8e899cf1d997950e0bf0 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 29 Apr 2024 13:28:32 -0500 Subject: [PATCH 5/8] Updated bcf C++ demo to store all data in column-major format and fixed symbol name clash in BCFParameters initialization --- debug/bcf_debug.cpp | 36 +++++----- include/stochtree/cpp_api.h | 139 ++++++++++++++++++------------------ 2 files changed, 88 insertions(+), 87 deletions(-) diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 52c44436..77faaf70 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -51,11 +51,11 @@ double mu1(std::vector& covariates, int n, int x_cols, int i) { CHECK_GE(x_cols, 5); CHECK_GT(n, i); double x1, x2, x3, x4, x5; - x1 = covariates[i*x_cols + 0]; - x2 = covariates[i*x_cols + 1]; - x3 = covariates[i*x_cols + 2]; - x4 = covariates[i*x_cols + 3]; - x5 = covariates[i*x_cols + 4]; + x1 = covariates[n*0 + i]; + x2 = covariates[n*1 + i]; + x3 = covariates[n*2 + i]; + x4 = covariates[n*3 + i]; + x5 = covariates[n*4 + i]; return 1.0 + g(x1,x2,x3,x4,x5) + x1*x3; } @@ -63,11 +63,11 @@ double mu2(std::vector& covariates, int n, int x_cols, int i) { CHECK_GE(x_cols, 5); CHECK_GT(n, i); double x1, x2, x3, x4, x5; - x1 = covariates[i*x_cols + 0]; - x2 = covariates[i*x_cols + 1]; - x3 = covariates[i*x_cols + 2]; - x4 = covariates[i*x_cols + 3]; - x5 = covariates[i*x_cols + 4]; + x1 = covariates[n*0 + i]; + x2 = covariates[n*1 + i]; + x3 = covariates[n*2 + i]; + x4 = covariates[n*3 + i]; + x5 = covariates[n*4 + i]; return 1.0 + g(x1,x2,x3,x4,x5) + 6.0*std::abs(x3-1); } @@ -79,11 +79,11 @@ double tau2(std::vector& covariates, int n, int x_cols, int i) { CHECK_GE(x_cols, 5); CHECK_GT(n, i); double x1, x2, x3, x4, x5; - x1 = covariates[i*x_cols + 0]; - x2 = covariates[i*x_cols + 1]; - x3 = covariates[i*x_cols + 2]; - x4 = covariates[i*x_cols + 3]; - x5 = covariates[i*x_cols + 4]; + x1 = covariates[n*0 + i]; + x2 = covariates[n*1 + i]; + x3 = covariates[n*2 + i]; + x4 = covariates[n*3 + i]; + x5 = covariates[n*4 + i]; return 1 + 2.0*x2*x4; } @@ -232,9 +232,9 @@ void RunAPI() { std::vector covariates_pi(n*(x_cols+1)); for (int i = 0; i < n; i++) { for (int j = 0; j < x_cols; j++) { - covariates_pi[i*(x_cols+1) + j] = covariates_raw[i*x_cols + j]; + covariates_pi[n*j + i] = covariates_raw[n*j + i]; } - covariates_pi[i*(x_cols+1) + x_cols] = propensity_raw[i]; + covariates_pi[n*x_cols + i] = propensity_raw[i]; } // Define internal datasets @@ -282,7 +282,7 @@ void RunAPI() { // ForestLeafModel leaf_model_type_mu = ForestLeafModel::kConstant; // ForestLeafModel leaf_model_type_tau = ForestLeafModel::kUnivariateRegression; int num_gfr_samples = 10; - int num_mcmc_samples = 1000; + int num_mcmc_samples = 10; int num_samples = num_gfr_samples + num_mcmc_samples; // // Set leaf model parameters diff --git a/include/stochtree/cpp_api.h b/include/stochtree/cpp_api.h index 71debc55..57bb05f2 100644 --- a/include/stochtree/cpp_api.h +++ b/include/stochtree/cpp_api.h @@ -20,7 +20,8 @@ namespace StochTree { -struct BCFParameters { +class BCFParameters { + public: int cutpoint_grid_size; double sigma_leaf_mu; double sigma_leaf_tau; @@ -50,75 +51,75 @@ struct BCFParameters { double leaf_init_mu; double leaf_init_tau; - BCFParameters(int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, - double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, - int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, - double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, - double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, - std::vector& feature_types_mu, std::vector& feature_types_tau, - int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { - cutpoint_grid_size = cutpoint_grid_size; - sigma_leaf_mu = sigma_leaf_mu; - sigma_leaf_tau = sigma_leaf_tau; - alpha_mu = alpha_mu; - alpha_tau = alpha_tau; - beta_mu = beta_mu; - beta_tau = beta_tau; - min_samples_leaf_mu = min_samples_leaf_mu; - min_samples_leaf_tau = min_samples_leaf_tau; - nu = nu; - lamb = lamb; - a_leaf_mu = a_leaf_mu; - a_leaf_tau = a_leaf_tau; - b_leaf_mu = b_leaf_mu; - b_leaf_tau = b_leaf_tau; - sigma2 = sigma2; - num_trees_mu = num_trees_mu; - num_trees_tau = num_trees_tau; - b1 = b1; - b0 = b0; - num_gfr = num_gfr; - num_burnin = num_burnin; - num_mcmc = num_mcmc; - feature_types_mu = feature_types_mu; - feature_types_tau = feature_types_tau; - leaf_init_mu = leaf_init_mu; - leaf_init_tau = leaf_init_tau; + BCFParameters(int cutpoint_grid_size_, double sigma_leaf_mu_, double sigma_leaf_tau_, + double alpha_mu_, double alpha_tau_, double beta_mu_, double beta_tau_, + int min_samples_leaf_mu_, int min_samples_leaf_tau_, double nu_, double lamb_, + double a_leaf_mu_, double a_leaf_tau_, double b_leaf_mu_, double b_leaf_tau_, + double sigma2_, int num_trees_mu_, int num_trees_tau_, double b1_, double b0_, + std::vector& feature_types_mu_, std::vector& feature_types_tau_, + int num_gfr_, int num_burnin_, int num_mcmc_, double leaf_init_mu_, double leaf_init_tau_) { + cutpoint_grid_size = cutpoint_grid_size_; + sigma_leaf_mu = sigma_leaf_mu_; + sigma_leaf_tau = sigma_leaf_tau_; + alpha_mu = alpha_mu_; + alpha_tau = alpha_tau_; + beta_mu = beta_mu_; + beta_tau = beta_tau_; + min_samples_leaf_mu = min_samples_leaf_mu_; + min_samples_leaf_tau = min_samples_leaf_tau_; + nu = nu_; + lamb = lamb_; + a_leaf_mu = a_leaf_mu_; + a_leaf_tau = a_leaf_tau_; + b_leaf_mu = b_leaf_mu_; + b_leaf_tau = b_leaf_tau_; + sigma2 = sigma2_; + num_trees_mu = num_trees_mu_; + num_trees_tau = num_trees_tau_; + b1 = b1_; + b0 = b0_; + num_gfr = num_gfr_; + num_burnin = num_burnin_; + num_mcmc = num_mcmc_; + feature_types_mu = feature_types_mu_; + feature_types_tau = feature_types_tau_; + leaf_init_mu = leaf_init_mu_; + leaf_init_tau = leaf_init_tau_; } - BCFParameters(int cutpoint_grid_size, double sigma_leaf_mu, Eigen::MatrixXd& sigma_leaf_tau, - double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, - int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, - double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, - double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, - std::vector& feature_types_mu, std::vector& feature_types_tau, - int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { - cutpoint_grid_size = cutpoint_grid_size; - sigma_leaf_mu = sigma_leaf_mu; - sigma_leaf_tau_mat = sigma_leaf_tau; - alpha_mu = alpha_mu; - alpha_tau = alpha_tau; - beta_mu = beta_mu; - beta_tau = beta_tau; - min_samples_leaf_mu = min_samples_leaf_mu; - min_samples_leaf_tau = min_samples_leaf_tau; - nu = nu; - lamb = lamb; - a_leaf_mu = a_leaf_mu; - a_leaf_tau = a_leaf_tau; - b_leaf_mu = b_leaf_mu; - b_leaf_tau = b_leaf_tau; - sigma2 = sigma2; - num_trees_mu = num_trees_mu; - num_trees_tau = num_trees_tau; - b1 = b1; - b0 = b0; - num_gfr = num_gfr; - num_burnin = num_burnin; - num_mcmc = num_mcmc; - feature_types_mu = feature_types_mu; - feature_types_tau = feature_types_tau; - leaf_init_mu = leaf_init_mu; - leaf_init_tau = leaf_init_tau; + BCFParameters(int cutpoint_grid_size_, double sigma_leaf_mu_, Eigen::MatrixXd& sigma_leaf_tau_, + double alpha_mu_, double alpha_tau_, double beta_mu_, double beta_tau_, + int min_samples_leaf_mu_, int min_samples_leaf_tau_, double nu_, double lamb_, + double a_leaf_mu_, double a_leaf_tau_, double b_leaf_mu_, double b_leaf_tau_, + double sigma2_, int num_trees_mu_, int num_trees_tau_, double b1_, double b0_, + std::vector& feature_types_mu_, std::vector& feature_types_tau_, + int num_gfr_, int num_burnin_, int num_mcmc_, double leaf_init_mu_, double leaf_init_tau_) { + cutpoint_grid_size = cutpoint_grid_size_; + sigma_leaf_mu = sigma_leaf_mu_; + sigma_leaf_tau_mat = sigma_leaf_tau_; + alpha_mu = alpha_mu_; + alpha_tau = alpha_tau_; + beta_mu = beta_mu_; + beta_tau = beta_tau_; + min_samples_leaf_mu = min_samples_leaf_mu_; + min_samples_leaf_tau = min_samples_leaf_tau_; + nu = nu_; + lamb = lamb_; + a_leaf_mu = a_leaf_mu_; + a_leaf_tau = a_leaf_tau_; + b_leaf_mu = b_leaf_mu_; + b_leaf_tau = b_leaf_tau_; + sigma2 = sigma2_; + num_trees_mu = num_trees_mu_; + num_trees_tau = num_trees_tau_; + b1 = b1_; + b0 = b0_; + num_gfr = num_gfr_; + num_burnin = num_burnin_; + num_mcmc = num_mcmc_; + feature_types_mu = feature_types_mu_; + feature_types_tau = feature_types_tau_; + leaf_init_mu = leaf_init_mu_; + leaf_init_tau = leaf_init_tau_; } }; From 6c49b8755df6fa847b4f370060fc416c2ceb68b4 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 29 Apr 2024 19:35:30 -0500 Subject: [PATCH 6/8] Added adaptive coding parameter sampler --- debug/bcf_debug.cpp | 24 ++ include/stochtree/category_tracker.h | 13 + include/stochtree/container.h | 6 + include/stochtree/cpp_api.h | 416 ++++++++++++++++++++++++++- include/stochtree/ensemble.h | 64 ++++- include/stochtree/meta.h | 8 +- src/container.cpp | 72 +++++ 7 files changed, 593 insertions(+), 10 deletions(-) diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 77faaf70..1fa61786 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -301,6 +301,15 @@ void RunAPI() { // Initialize storage for samples of variance std::vector global_variance_samples(num_samples); std::vector leaf_variance_samples_mu(num_samples); + + // Storage for samples of b1 and b0 + std::vector b1_samples(num_samples); + std::vector b0_samples(num_samples); + + // Storage for samples of muhat, tauhat, and yhat + MatrixObject muhat_samples(n, num_samples); + MatrixObject yhat_samples(n, num_samples); + VectorObject tauhat_samples(n*num_samples*1); // Initialize the BCF sampler BCFModel bcf = BCFModel(); @@ -308,6 +317,9 @@ void RunAPI() { covariates_raw.data(), x_cols, treatment_raw.data(), 1, true); bcf.ResetGlobalVarSamples(global_variance_samples.data(), num_samples); bcf.ResetPrognosticLeafVarSamples(leaf_variance_samples_mu.data(), num_samples); + bcf.ResetTreatedCodingSamples(b1_samples.data(), num_samples); + bcf.ResetControlCodingSamples(b0_samples.data(), num_samples); + bcf.ResetTrainPredictionSamples(muhat_samples.data(), tauhat_samples.data(), yhat_samples.data(), n, num_samples, 1); // Run the BCF sampler bcf.SampleBCF(&forest_samples_mu, &forest_samples_tau, &rng, cutpoint_grid_size, @@ -315,6 +327,18 @@ void RunAPI() { min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, feature_types_mu, feature_types_tau, num_gfr_samples, 0, num_mcmc_samples, 0.0, 0.0); + + // Analysis predictions + double ssr = 0; + for (int i = 0; i < n; i++) { + double yhat_mean = 0; + for (int j = num_gfr_samples; j < num_samples; j++) { + yhat_mean += yhat_samples(i, j) / (num_samples - num_gfr_samples); + } + ssr += (outcome_raw.at(i) - yhat_mean)*(outcome_raw.at(i) - yhat_mean); + } + double rmse = std::sqrt(ssr / n); + std::cout << "Train set RMSE = " << rmse << std::endl; } } // namespace StochTree diff --git a/include/stochtree/category_tracker.h b/include/stochtree/category_tracker.h index 96ab039a..8423f9a2 100644 --- a/include/stochtree/category_tracker.h +++ b/include/stochtree/category_tracker.h @@ -79,6 +79,7 @@ class SampleCategoryMapper { */ class CategorySampleTracker { public: + CategorySampleTracker() {} CategorySampleTracker(const std::vector& group_indices) { int n = group_indices.size(); indices_ = std::vector(n); @@ -129,6 +130,18 @@ class CategorySampleTracker { return category_length_[category_id_map_[category_id]]; } + std::vector::iterator CategoryBeginIterator(int category_id) { + data_size_t category_begin = CategoryBegin(category_id); + auto begin_iter = indices_.begin(); + return begin_iter + category_begin; + } + + std::vector::iterator CategoryEndIterator(int category_id) { + data_size_t category_end = CategoryEnd(category_id); + auto begin_iter = indices_.begin(); + return begin_iter + category_end; + } + /*! \brief Number of total categories stored */ inline data_size_t NumCategories() {return category_count_;} diff --git a/include/stochtree/container.h b/include/stochtree/container.h index 1af8dd2b..c93a2c97 100644 --- a/include/stochtree/container.h +++ b/include/stochtree/container.h @@ -33,6 +33,12 @@ class ForestContainer { std::vector Predict(ForestDataset& dataset); std::vector PredictRaw(ForestDataset& dataset); std::vector PredictRaw(ForestDataset& dataset, int forest_num); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output, int forest_num); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple, int forest_num); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple_minuend, VectorMap& scalar_multiple_subtrahend); + void PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple_minuend, VectorMap& scalar_multiple_subtrahend, int forest_num); inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();} inline int32_t NumSamples() {return num_samples_;} diff --git a/include/stochtree/cpp_api.h b/include/stochtree/cpp_api.h index 57bb05f2..2dcb1ebc 100644 --- a/include/stochtree/cpp_api.h +++ b/include/stochtree/cpp_api.h @@ -7,6 +7,7 @@ #define STOCHTREE_CPP_API_H_ #include +#include #include #include #include @@ -210,36 +211,47 @@ class BCFModel { void ResetGlobalVarSamples(double* data_ptr, int num_samples) { new (&global_var_samples_) MatrixMap(data_ptr, num_samples, 1); global_var_random_ = true; + global_var_samples_mapped_ = true; } void ResetPrognosticLeafVarSamples(double* data_ptr, int num_samples) { new (&prognostic_leaf_var_samples_) MatrixMap(data_ptr, num_samples, 1); prognostic_leaf_var_random_ = true; + prognostic_leaf_var_samples_mapped_ = true; } void ResetTreatmentLeafVarSamples(double* data_ptr, int num_samples) { new (&treatment_leaf_var_samples_) MatrixMap(data_ptr, num_samples, 1); treatment_leaf_var_random_ = true; + treatment_leaf_var_samples_mapped_ = true; } void ResetTreatedCodingSamples(double* data_ptr, int num_samples) { new (&b1_samples_) MatrixMap(data_ptr, num_samples, 1); + b1_samples_mapped_ = true; } void ResetControlCodingSamples(double* data_ptr, int num_samples) { new (&b0_samples_) MatrixMap(data_ptr, num_samples, 1); + b0_samples_mapped_ = true; } - void ResetTrainPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples) { + void ResetTrainPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples, int treatment_dim) { new (&muhat_train_samples_) MatrixMap(muhat_data_ptr, num_obs, num_samples); - new (&tauhat_train_samples_) MatrixMap(tauhat_data_ptr, num_obs, num_samples); + new (&tauhat_train_samples_) VectorMap(tauhat_data_ptr, num_obs*treatment_dim*num_samples); new (&yhat_train_samples_) MatrixMap(yhat_data_ptr, num_obs, num_samples); + muhat_train_samples_mapped_ = true; + tauhat_train_samples_mapped_ = true; + yhat_train_samples_mapped_ = true; } - void ResetTestPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples) { + void ResetTestPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples, int treatment_dim) { new (&muhat_test_samples_) MatrixMap(muhat_data_ptr, num_obs, num_samples); - new (&tauhat_test_samples_) MatrixMap(tauhat_data_ptr, num_obs, num_samples); + new (&tauhat_test_samples_) VectorMap(tauhat_data_ptr, num_obs*treatment_dim*num_samples); new (&yhat_test_samples_) MatrixMap(yhat_data_ptr, num_obs, num_samples); + muhat_test_samples_mapped_ = true; + tauhat_test_samples_mapped_ = true; + yhat_test_samples_mapped_ = true; } private: @@ -273,6 +285,19 @@ class BCFModel { MatrixMap muhat_test_samples_{NULL,1,1}; MatrixMap tauhat_test_samples_{NULL,1,1}; MatrixMap yhat_test_samples_{NULL,1,1}; + + // Internal details about whether a MatrixMap has been mapped to a data buffer + bool global_var_samples_mapped_{false}; + bool prognostic_leaf_var_samples_mapped_{false}; + bool treatment_leaf_var_samples_mapped_{false}; + bool b1_samples_mapped_{false}; + bool b0_samples_mapped_{false}; + bool muhat_train_samples_mapped_{false}; + bool tauhat_train_samples_mapped_{false}; + bool yhat_train_samples_mapped_{false}; + bool muhat_test_samples_mapped_{false}; + bool tauhat_test_samples_mapped_{false}; + bool yhat_test_samples_mapped_{false}; TauModelType InitializeTauLeafModel(double sigma_leaf, Eigen::MatrixXd& sigma_leaf_mat) { if constexpr (std::is_same_v) { @@ -283,6 +308,11 @@ class BCFModel { } void SampleBCFInternal(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, BCFParameters& params) { + // Input checks + CHECK(yhat_train_samples_mapped_); + CHECK(muhat_train_samples_mapped_); + CHECK(tauhat_train_samples_mapped_); + // Initialize leaf models for mu and tau forests GaussianConstantLeafModel leaf_model_mu = GaussianConstantLeafModel(params.sigma_leaf_mu); TauModelType leaf_model_tau = InitializeTauLeafModel(params.sigma_leaf_tau, params.sigma_leaf_tau_mat); @@ -306,6 +336,12 @@ class BCFModel { TreePrior tree_prior_mu = TreePrior(params.alpha_mu, params.beta_mu, params.min_samples_leaf_mu); TreePrior tree_prior_tau = TreePrior(params.alpha_tau, params.beta_tau, params.min_samples_leaf_tau); + // Test set details (if a test set is provided) + int n_test; + if (has_test_) { + n_test = forest_dataset_mu_test_.NumObservations(); + } + // Initialize leaf values // TODO: handle multivariate tau case forest_samples_mu->SetLeafValue(0, params.leaf_init_mu); @@ -321,6 +357,75 @@ class BCFModel { if (prognostic_leaf_var_random_) prognostic_leaf_var_model = LeafNodeHomoskedasticVarianceModel(); if (treatment_leaf_var_random_) treatment_leaf_var_model = LeafNodeHomoskedasticVarianceModel(); + // If treatment is binary, update the basis of the tau regression to use (b1*Z + b0*(1-Z)) + double b1 = params.b1; + double b0 = params.b0; + MatrixObject Z_orig(n, treatment_dim_); + MatrixObject Z_adj; + std::vector Z_int(n * treatment_dim_); + if (treatment_binary_) { + for (int i = 0; i < n; i++) { + for (int j = 0; j < treatment_dim_; j++) { + Z_orig(i,j) = forest_dataset_tau_train_.BasisValue(i,j); + } + } + Z_adj = ( + (b1 * Z_orig).array() + + (b0 * (1 - Z_orig.array())) + ); + forest_dataset_tau_train_.UpdateBasis(Z_adj.data(), n, treatment_dim_, false); + for (int i = 0; i < n; i++) { + for (int j = 0; j < treatment_dim_; j++) { + Z_int.at(i + n * j) = static_cast(forest_dataset_tau_train_.BasisValue(i,j)); + } + } + } + + // Do the same for test set + MatrixObject Z_orig_test; + MatrixObject Z_adj_test; + if (has_test_) { + Z_orig_test = MatrixObject(n_test, treatment_dim_); + if (treatment_binary_) { + for (int i = 0; i < n_test; i++) { + for (int j = 0; j < treatment_dim_; j++) { + Z_orig_test(i,j) = forest_dataset_tau_train_.BasisValue(i,j); + } + } + Z_adj_test = ( + (b1 * Z_orig_test).array() + + (b0 * (1 - Z_orig_test.array())) + ); + forest_dataset_tau_test_.UpdateBasis(Z_adj_test.data(), n_test, treatment_dim_, false); + } + } + + // Override adaptive coding if treatment is not binary or treatment is multivariate + if (adaptive_coding_) { + if (!treatment_binary_) { + adaptive_coding_ = false; + } + if (treatment_dim_ > 1) { + adaptive_coding_ = false; + } + } + + // Prepare to run the adaptive coding sampler (if requested) + VectorObject initial_outcome; + VectorObject outcome_minus_mu; + VectorObject muhat; + VectorObject tauhat; + CategorySampleTracker treatment_category_sampler_tracker; + UnivariateNormalSampler adaptive_coding_sampler = UnivariateNormalSampler(); + if (adaptive_coding_) { + initial_outcome.resize(n); + outcome_minus_mu.resize(n); + muhat.resize(n); + tauhat.resize(n * treatment_dim_); + for (int i = 0; i < n; i++) initial_outcome(i) = residual_train_.GetElement(i); + treatment_category_sampler_tracker = CategorySampleTracker(Z_int); + } + // Initial values of (potentially random) parameters double sigma2 = params.sigma2; double leaf_scale_mu = params.sigma_leaf_mu; @@ -338,6 +443,29 @@ class BCFModel { // Sample mu ensemble mu_sampler_gfr.SampleOneIter(mu_tracker, *forest_samples_mu, leaf_model_mu, forest_dataset_mu_train_, residual_train_, tree_prior_mu, *rng, variable_weights_mu, sigma2, params.feature_types_mu, true); + // Predict from the mu ensemble for the train set + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + double mupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + mupred += forest_samples_mu->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(mu_tracker.GetNodeId(obs_idx, tree_idx), 0); + } + muhat_train_samples_(obs_idx, i) = mupred; + if (adaptive_coding_) muhat(obs_idx) = mupred; + } + + // Predict from the mu ensemble for the test set (if provided) + if (has_test_) { + for (int obs_idx = 0; obs_idx < n_test; obs_idx++) { + double mupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + auto &tree = *forest_samples_mu->GetEnsemble(i)->GetTree(tree_idx); + int32_t nidx = EvaluateTree(tree, forest_dataset_mu_test_.GetCovariates(), i); + mupred += tree.LeafValue(nidx, 0); + } + muhat_test_samples_(obs_idx, i) = mupred; + } + } + // Sample leaf node variance if (prognostic_leaf_var_random_) { leaf_scale_mu = prognostic_leaf_var_model.SampleVarianceParameter(forest_samples_mu->GetEnsemble(i), params.a_leaf_mu, params.b_leaf_mu, *rng); @@ -353,6 +481,130 @@ class BCFModel { // Sample tau ensemble tau_sampler_gfr.SampleOneIter(tau_tracker, *forest_samples_tau, leaf_model_tau, forest_dataset_tau_train_, residual_train_, tree_prior_tau, *rng, variable_weights_tau, sigma2, params.feature_types_tau, true); + // Sample adaptive coding parameter for tau (if applicable), before defining the tau(x) prediction for a sample + if (adaptive_coding_) { + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + taupred += forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), 0); + } + tauhat(obs_idx + n * treatment_dim_idx) = taupred; + } + } + + // Data checks and partial residualization + CHECK_EQ(muhat.size(), tauhat.size()); + CHECK_EQ(muhat.size(), initial_outcome.size()); + outcome_minus_mu = initial_outcome.array() - muhat.array(); + + // Compute sufficient statistics for control group observations + std::vector::iterator control_begin_iter = treatment_category_sampler_tracker.CategoryBeginIterator(0); + std::vector::iterator control_end_iter = treatment_category_sampler_tracker.CategoryEndIterator(0); + double sum_xx_control = 0.; + double sum_xy_control = 0.; + for (auto obs_it = control_begin_iter; obs_it != control_end_iter; i++) { + auto idx = *obs_it; + sum_xx_control += tauhat(idx)*tauhat(idx); + sum_xy_control += tauhat(idx)*outcome_minus_mu(idx); + } + + // Compute sufficient statistics for treatment group observations + std::vector::iterator treated_begin_iter = treatment_category_sampler_tracker.CategoryBeginIterator(1); + std::vector::iterator treated_end_iter = treatment_category_sampler_tracker.CategoryEndIterator(1); + double sum_xx_treated = 0.; + double sum_xy_treated = 0.; + for (auto obs_it = treated_begin_iter; obs_it != treated_end_iter; i++) { + auto idx = *obs_it; + sum_xx_treated += tauhat(idx)*tauhat(idx); + sum_xy_treated += tauhat(idx)*outcome_minus_mu(idx); + } + + // Compute the posterior mean and variance for treated and control coding parameters + double post_mean_control = (sum_xy_control / (sum_xx_control + 2*sigma2)); + double post_var_control = (sigma2 / (sum_xx_control + 2*sigma2)); + double post_mean_treated = (sum_xy_treated / (sum_xx_treated + 2*sigma2)); + double post_var_treated = (sigma2 / (sum_xx_treated + 2*sigma2)); + + // Sample new adaptive coding parameters + b0 = adaptive_coding_sampler.Sample(post_mean_control, post_var_control, *rng); + b1 = adaptive_coding_sampler.Sample(post_mean_treated, post_var_treated, *rng); + + // Update sample container for b1 and b0 + b0_samples_(i) = b0; + b1_samples_(i) = b1; + + // Update basis used in the leaf regression in the next iteration + Z_adj = ( + (b1 * Z_orig).array() + + (b0 * (1 - Z_orig.array())) + ); + forest_dataset_tau_train_.UpdateBasis(Z_adj.data(), n, treatment_dim_, false); + + // Update basis used in the leaf regression in the next iteration + if (has_test_) { + Z_adj_test = ( + (b1 * Z_orig).array() + + (b0 * (1 - Z_orig.array())) + ); + forest_dataset_tau_test_.UpdateBasis(Z_adj_test.data(), n_test, treatment_dim_, false); + } + + // Update residual and tree predictions with the new basis + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + // double outcome_pred_tau = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + // Retrieve the "old" prediction for tree_idx back and the residual value to be updated + double prev_resid = residual_train_.GetElement(obs_idx); + double prev_pred = tau_tracker.GetTreeSamplePrediction(obs_idx, tree_idx); + // Compute the new prediction for tree_idx with updated basis + double outcome_pred_tau_tree = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double leaf_val = forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), treatment_dim_idx); + double basis_val = forest_dataset_tau_train_.BasisValue(obs_idx, treatment_dim_idx); + outcome_pred_tau_tree += leaf_val * basis_val; + } + // Recalculate prediction in tau_tracker + tau_tracker.SetTreeSamplePrediction(obs_idx, tree_idx, outcome_pred_tau_tree); + // outcome_pred_tau += outcome_pred_tau_tree; + // Update residual (adding back "old" prediction and subtracting new prediction) + residual_train_.SetElement(obs_idx, prev_resid + prev_pred + outcome_pred_tau_tree); + } + } + } + + // Predict from the tau ensemble for the train set + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + double outcome_pred_tau = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + taupred += forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), treatment_dim_idx); + } + tauhat_train_samples_(obs_idx + treatment_dim_idx*n + i*n*treatment_dim_) = taupred * (b1 - b0); + outcome_pred_tau += taupred * forest_dataset_tau_train_.BasisValue(obs_idx, treatment_dim_idx); + } + yhat_train_samples_(obs_idx, i) = outcome_pred_tau + muhat_train_samples_(obs_idx, i); + } + + // Predict from the tau ensemble for the test set (if provided) + if (has_test_) { + for (int obs_idx = 0; obs_idx < n_test; obs_idx++) { + double outcome_pred_tau = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + auto &tree = *forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx); + int32_t nidx = EvaluateTree(tree, forest_dataset_tau_test_.GetCovariates(), i); + taupred += tree.LeafValue(nidx, treatment_dim_idx); + } + tauhat_test_samples_(obs_idx + treatment_dim_idx*n + i*n*treatment_dim_) = taupred * (b1 - b0); + outcome_pred_tau += taupred * forest_dataset_tau_test_.BasisValue(obs_idx, treatment_dim_idx); + } + yhat_test_samples_(obs_idx, i) = outcome_pred_tau + muhat_test_samples_(obs_idx, i); + } + } + // Sample leaf node variance if (treatment_leaf_var_random_) { leaf_scale_tau = treatment_leaf_var_model.SampleVarianceParameter(forest_samples_tau->GetEnsemble(i), params.a_leaf_tau, params.b_leaf_tau, *rng); @@ -368,16 +620,39 @@ class BCFModel { } if (params.num_burnin + params.num_mcmc > 0) { - // Initialize GFR sampler for mu and tau + // Initialize MCMC sampler for mu and tau MCMCForestSampler mu_sampler_mcmc = MCMCForestSampler(); MCMCForestSampler tau_sampler_mcmc = MCMCForestSampler(); - // Run the GFR sampler + // Run the MCMC sampler for (int i = params.num_gfr; i < params.num_gfr + params.num_burnin + params.num_mcmc; i++) { // Sample mu ensemble mu_sampler_mcmc.SampleOneIter(mu_tracker, *forest_samples_mu, leaf_model_mu, forest_dataset_mu_train_, residual_train_, tree_prior_mu, *rng, variable_weights_mu, sigma2, true); + // Predict from the mu ensemble for the train set + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + double mupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + mupred += forest_samples_mu->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(mu_tracker.GetNodeId(obs_idx, tree_idx), 0); + } + muhat_train_samples_(obs_idx, i) = mupred; + if (adaptive_coding_) muhat(obs_idx) = mupred; + } + + // Predict from the mu ensemble for the test set (if provided) + if (has_test_) { + for (int obs_idx = 0; obs_idx < n_test; obs_idx++) { + double mupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + auto &tree = *forest_samples_mu->GetEnsemble(i)->GetTree(tree_idx); + int32_t nidx = EvaluateTree(tree, forest_dataset_mu_test_.GetCovariates(), i); + mupred += tree.LeafValue(nidx, 0); + } + muhat_test_samples_(obs_idx, i) = mupred; + } + } + // Sample leaf node variance if (prognostic_leaf_var_random_) { leaf_scale_mu = prognostic_leaf_var_model.SampleVarianceParameter(forest_samples_mu->GetEnsemble(i), params.a_leaf_mu, params.b_leaf_mu, *rng); @@ -393,6 +668,130 @@ class BCFModel { // Sample tau ensemble tau_sampler_mcmc.SampleOneIter(tau_tracker, *forest_samples_tau, leaf_model_tau, forest_dataset_tau_train_, residual_train_, tree_prior_tau, *rng, variable_weights_tau, sigma2, true); + // Sample adaptive coding parameter for tau (if applicable), before defining the tau(x) prediction for a sample + if (adaptive_coding_) { + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_mu->NumTrees(); tree_idx++) { + taupred += forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), 0); + } + tauhat(obs_idx + n * treatment_dim_idx) = taupred; + } + } + + // Data checks and partial residualization + CHECK_EQ(muhat.size(), tauhat.size()); + CHECK_EQ(muhat.size(), initial_outcome.size()); + outcome_minus_mu = initial_outcome.array() - muhat.array(); + + // Compute sufficient statistics for control group observations + std::vector::iterator control_begin_iter = treatment_category_sampler_tracker.CategoryBeginIterator(0); + std::vector::iterator control_end_iter = treatment_category_sampler_tracker.CategoryEndIterator(0); + double sum_xx_control = 0.; + double sum_xy_control = 0.; + for (auto obs_it = control_begin_iter; obs_it != control_end_iter; i++) { + auto idx = *obs_it; + sum_xx_control += tauhat(idx)*tauhat(idx); + sum_xy_control += tauhat(idx)*outcome_minus_mu(idx); + } + + // Compute sufficient statistics for treatment group observations + std::vector::iterator treated_begin_iter = treatment_category_sampler_tracker.CategoryBeginIterator(1); + std::vector::iterator treated_end_iter = treatment_category_sampler_tracker.CategoryEndIterator(1); + double sum_xx_treated = 0.; + double sum_xy_treated = 0.; + for (auto obs_it = treated_begin_iter; obs_it != treated_end_iter; i++) { + auto idx = *obs_it; + sum_xx_treated += tauhat(idx)*tauhat(idx); + sum_xy_treated += tauhat(idx)*outcome_minus_mu(idx); + } + + // Compute the posterior mean and variance for treated and control coding parameters + double post_mean_control = (sum_xy_control / (sum_xx_control + 2*sigma2)); + double post_var_control = (sigma2 / (sum_xx_control + 2*sigma2)); + double post_mean_treated = (sum_xy_treated / (sum_xx_treated + 2*sigma2)); + double post_var_treated = (sigma2 / (sum_xx_treated + 2*sigma2)); + + // Sample new adaptive coding parameters + b0 = adaptive_coding_sampler.Sample(post_mean_control, post_var_control, *rng); + b1 = adaptive_coding_sampler.Sample(post_mean_treated, post_var_treated, *rng); + + // Update sample container for b1 and b0 + b0_samples_(i) = b0; + b1_samples_(i) = b1; + + // Update basis used in the leaf regression in the next iteration + Z_adj = ( + (b1 * Z_orig).array() + + (b0 * (1 - Z_orig.array())) + ); + forest_dataset_tau_train_.UpdateBasis(Z_adj.data(), n, treatment_dim_, false); + + // Update basis used in the leaf regression in the next iteration + if (has_test_) { + Z_adj_test = ( + (b1 * Z_orig).array() + + (b0 * (1 - Z_orig.array())) + ); + forest_dataset_tau_test_.UpdateBasis(Z_adj_test.data(), n_test, treatment_dim_, false); + } + + // Update residual and tree predictions with the new basis + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + // double outcome_pred_tau = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + // Retrieve the "old" prediction for tree_idx back and the residual value to be updated + double prev_resid = residual_train_.GetElement(obs_idx); + double prev_pred = tau_tracker.GetTreeSamplePrediction(obs_idx, tree_idx); + // Compute the new prediction for tree_idx with updated basis + double outcome_pred_tau_tree = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double leaf_val = forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), treatment_dim_idx); + double basis_val = forest_dataset_tau_train_.BasisValue(obs_idx, treatment_dim_idx); + outcome_pred_tau_tree += leaf_val * basis_val; + } + // Recalculate prediction in tau_tracker + tau_tracker.SetTreeSamplePrediction(obs_idx, tree_idx, outcome_pred_tau_tree); + // outcome_pred_tau += outcome_pred_tau_tree; + // Update residual (adding back "old" prediction and subtracting new prediction) + residual_train_.SetElement(obs_idx, prev_resid + prev_pred + outcome_pred_tau_tree); + } + } + } + + // Predict from the tau ensemble for the train set + for (int obs_idx = 0; obs_idx < n; obs_idx++) { + double outcome_pred_tau = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + taupred += forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx)->LeafValue(tau_tracker.GetNodeId(obs_idx, tree_idx), treatment_dim_idx); + } + tauhat_train_samples_(obs_idx + treatment_dim_idx*n + i*n*treatment_dim_) = taupred * (b1 - b0); + outcome_pred_tau += taupred * forest_dataset_tau_train_.BasisValue(obs_idx, treatment_dim_idx); + } + yhat_train_samples_(obs_idx, i) = outcome_pred_tau + muhat_train_samples_(obs_idx, i); + } + + // Predict from the tau ensemble for the test set (if provided) + if (has_test_) { + for (int obs_idx = 0; obs_idx < n_test; obs_idx++) { + double outcome_pred_tau = 0.; + for (int treatment_dim_idx = 0; treatment_dim_idx < treatment_dim_; treatment_dim_idx++) { + double taupred = 0.; + for (int tree_idx = 0; tree_idx < forest_samples_tau->NumTrees(); tree_idx++) { + auto &tree = *forest_samples_tau->GetEnsemble(i)->GetTree(tree_idx); + int32_t nidx = EvaluateTree(tree, forest_dataset_tau_test_.GetCovariates(), i); + taupred += tree.LeafValue(nidx, treatment_dim_idx); + } + tauhat_test_samples_(obs_idx + treatment_dim_idx*n + i*n*treatment_dim_) = taupred * (b1 - b0); + outcome_pred_tau += taupred * forest_dataset_tau_test_.BasisValue(obs_idx, treatment_dim_idx); + } + yhat_test_samples_(obs_idx, i) = outcome_pred_tau + muhat_test_samples_(obs_idx, i); + } + } + // Sample leaf node variance if (treatment_leaf_var_random_) { leaf_scale_tau = treatment_leaf_var_model.SampleVarianceParameter(forest_samples_tau->GetEnsemble(i), params.a_leaf_tau, params.b_leaf_tau, *rng); @@ -406,7 +805,10 @@ class BCFModel { } } } - } + + // Predict from each forest + + } }; } // namespace StochTree diff --git a/include/stochtree/ensemble.h b/include/stochtree/ensemble.h index afbd4c79..057c4f46 100644 --- a/include/stochtree/ensemble.h +++ b/include/stochtree/ensemble.h @@ -145,7 +145,7 @@ class TreeEnsemble { } inline void PredictRawInplace(ForestDataset& dataset, std::vector &output, - int tree_begin, int tree_end, data_size_t offset = 0) { + int tree_begin, int tree_end, data_size_t offset = 0) { double pred; MatrixMap covariates = dataset.GetCovariates(); CHECK_EQ(output_dimension_, trees_[0]->OutputDimension()); @@ -167,6 +167,68 @@ class TreeEnsemble { } } + /*! \brief Predict the "raw" output of an ensemble in column-major format */ + inline void PredictRawInplace(ForestDataset& dataset, VectorMap& output, data_size_t offset = 0) { + PredictRawInplace(dataset, output, 0, trees_.size(), offset); + } + + /*! \brief Predict the "raw" output of an ensemble in column-major format beginning with `tree_begin` and ending with `tree_end` */ + inline void PredictRawInplace(ForestDataset& dataset, VectorMap& output, + int tree_begin, int tree_end, data_size_t offset = 0) { + double pred; + MatrixMap covariates = dataset.GetCovariates(); + CHECK_EQ(output_dimension_, trees_[0]->OutputDimension()); + data_size_t n = covariates.rows(); + data_size_t total_output_size = n * output_dimension_; + if (output.size() < total_output_size + offset) { + Log::Fatal("Mismatched size of raw prediction vector and training data"); + } + for (data_size_t i = 0; i < n; i++) { + for (int32_t k = 0; k < output_dimension_; k++) { + pred = 0.0; + for (size_t j = tree_begin; j < tree_end; j++) { + auto &tree = *trees_[j]; + int32_t nidx = EvaluateTree(tree, covariates, i); + pred += tree.LeafValue(nidx, k); + } + output(offset + n*k + i) = pred; + } + } + } + + /*! \brief Predict the "raw" output of an ensemble in column-major format, multiplying every observation by a scalar_multiple + * which is assumed to be constant over all observations + */ + inline void PredictRawInplace(ForestDataset& dataset, VectorMap& output, double scalar_multiple, data_size_t offset = 0) { + PredictRawInplace(dataset, output, scalar_multiple, 0, trees_.size(), offset); + } + + /*! \brief Predict the "raw" output of an ensemble in column-major format beginning with `tree_begin` and ending with `tree_end`, + * multiplying every observation by a scalar multiple which is assumed to be constant over all observations + */ + inline void PredictRawInplace(ForestDataset& dataset, VectorMap& output, double scalar_multiple, + int tree_begin, int tree_end, data_size_t offset = 0) { + double pred; + MatrixMap covariates = dataset.GetCovariates(); + CHECK_EQ(output_dimension_, trees_[0]->OutputDimension()); + data_size_t n = covariates.rows(); + data_size_t total_output_size = n * output_dimension_; + if (output.size() < total_output_size + offset) { + Log::Fatal("Mismatched size of raw prediction vector and training data"); + } + for (data_size_t i = 0; i < n; i++) { + for (int32_t k = 0; k < output_dimension_; k++) { + pred = 0.0; + for (size_t j = tree_begin; j < tree_end; j++) { + auto &tree = *trees_[j]; + int32_t nidx = EvaluateTree(tree, covariates, i); + pred += tree.LeafValue(nidx, k); + } + output(offset + n*k + i) = pred * scalar_multiple; + } + } + } + inline int32_t NumTrees() { return num_trees_; } diff --git a/include/stochtree/meta.h b/include/stochtree/meta.h index 10b96f9f..f17ff25e 100644 --- a/include/stochtree/meta.h +++ b/include/stochtree/meta.h @@ -73,10 +73,14 @@ enum RandomEffectsType { /*! \brief Eigen Map objects that expose matrix / vector operations directly on raw buffers without copying data */ typedef Eigen::Matrix MatrixObject; typedef Eigen::Matrix VectorObject; +typedef Eigen::Matrix IntMatrixObject; +typedef Eigen::Matrix IntVectorObject; /*! \brief Eigen Map objects that expose matrix / vector operations directly on raw buffers without copying data */ -typedef Eigen::Map> MatrixMap; -typedef Eigen::Map> VectorMap; +typedef Eigen::Map MatrixMap; +typedef Eigen::Map VectorMap; +typedef Eigen::Map IntMatrixMap; +typedef Eigen::Map IntVectorMap; /*! \brief Type of data size */ typedef int32_t data_size_t; diff --git a/src/container.cpp b/src/container.cpp index 747e6995..4ee7e810 100644 --- a/src/container.cpp +++ b/src/container.cpp @@ -100,6 +100,78 @@ std::vector ForestContainer::PredictRaw(ForestDataset& dataset, int fore return output; } +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output) { + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_ * num_samples_; + // std::vector output(total_output_size); + data_size_t offset = 0; + for (int i = 0; i < num_samples_; i++) { + auto num_trees = forests_[i]->NumTrees(); + forests_[i]->PredictRawInplace(dataset, output, 0, num_trees, offset); + offset += n * output_dimension_; + } +} + +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output, int forest_num) { + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_; + // std::vector output(total_output_size); + data_size_t offset = 0; + auto num_trees = forests_[forest_num]->NumTrees(); + forests_[forest_num]->PredictRawInplace(dataset, output, 0, num_trees, offset); +} + +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple) { + CHECK_EQ(scalar_multiple.size(), num_samples_); + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_ * num_samples_; + // std::vector output(total_output_size); + data_size_t offset = 0; + for (int i = 0; i < num_samples_; i++) { + auto num_trees = forests_[i]->NumTrees(); + forests_[i]->PredictRawInplace(dataset, output, scalar_multiple(i), 0, num_trees, offset); + offset += n * output_dimension_; + } +} + +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple, int forest_num) { + CHECK_EQ(scalar_multiple.size(), num_samples_); + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_; + // std::vector output(total_output_size); + data_size_t offset = 0; + auto num_trees = forests_[forest_num]->NumTrees(); + forests_[forest_num]->PredictRawInplace(dataset, output, scalar_multiple(forest_num), 0, num_trees, offset); +} + +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple_minuend, VectorMap& scalar_multiple_subtrahend) { + CHECK_EQ(scalar_multiple_minuend.size(), num_samples_); + CHECK_EQ(scalar_multiple_subtrahend.size(), num_samples_); + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_ * num_samples_; + // std::vector output(total_output_size); + data_size_t offset = 0; + double mult; + for (int i = 0; i < num_samples_; i++) { + mult = scalar_multiple_minuend(i) - scalar_multiple_subtrahend(i); + auto num_trees = forests_[i]->NumTrees(); + forests_[i]->PredictRawInplace(dataset, output, mult, 0, num_trees, offset); + offset += n * output_dimension_; + } +} + +void ForestContainer::PredictRawInplace(ForestDataset& dataset, VectorMap& output, VectorMap& scalar_multiple_minuend, VectorMap& scalar_multiple_subtrahend, int forest_num) { + CHECK_EQ(scalar_multiple_minuend.size(), num_samples_); + CHECK_EQ(scalar_multiple_subtrahend.size(), num_samples_); + data_size_t n = dataset.NumObservations(); + // data_size_t total_output_size = n * output_dimension_; + // std::vector output(total_output_size); + data_size_t offset = 0; + double mult = scalar_multiple_minuend(forest_num) - scalar_multiple_subtrahend(forest_num); + auto num_trees = forests_[forest_num]->NumTrees(); + forests_[forest_num]->PredictRawInplace(dataset, output, mult, 0, num_trees, offset); +} + /*! \brief Save to JSON */ json ForestContainer::to_json() { json result_obj; From 798d83754c5a5d2099bd7552b128963886df1a59 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 30 Apr 2024 00:57:55 -0500 Subject: [PATCH 7/8] Template-free wrapper of BCF Model (for use as a type in the R / Python interface) --- debug/bcf_debug.cpp | 3 +- include/stochtree/cpp_api.h | 169 ++++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 1 deletion(-) diff --git a/debug/bcf_debug.cpp b/debug/bcf_debug.cpp index 1fa61786..2bcf7846 100644 --- a/debug/bcf_debug.cpp +++ b/debug/bcf_debug.cpp @@ -312,7 +312,8 @@ void RunAPI() { VectorObject tauhat_samples(n*num_samples*1); // Initialize the BCF sampler - BCFModel bcf = BCFModel(); + // BCFModel bcf = BCFModel(); + BCFModelWrapper bcf = BCFModelWrapper(true); bcf.LoadTrain(outcome_raw.data(), n, covariates_pi.data(), x_cols+1, covariates_raw.data(), x_cols, treatment_raw.data(), 1, true); bcf.ResetGlobalVarSamples(global_variance_samples.data(), num_samples); diff --git a/include/stochtree/cpp_api.h b/include/stochtree/cpp_api.h index 2dcb1ebc..2302a988 100644 --- a/include/stochtree/cpp_api.h +++ b/include/stochtree/cpp_api.h @@ -811,6 +811,175 @@ class BCFModel { } }; +class BCFModelWrapper { + private: + bool univariate_{true}; + BCFModel bcf_univariate_; + BCFModel bcf_multivariate_; + public: + BCFModelWrapper(bool univariate = true){ +// BCFModel bcf_univariate_{}; +// BCFModel bcf_multivariate_{}; +// bcf_univariate_ = BCFModel(); +// bcf_multivariate_ = BCFModel(); + if (univariate) univariate_ = true; + else univariate_ = false; + } + ~BCFModelWrapper(){} + void SampleBCF(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, + int cutpoint_grid_size, double sigma_leaf_mu, double sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + if (univariate_) { + bcf_univariate_.SampleBCF( + forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, + alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau + ); + } else { + bcf_multivariate_.SampleBCF( + forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, + alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau + ); + } + } + + void SampleBCF(ForestContainer* forest_samples_mu, ForestContainer* forest_samples_tau, std::mt19937* rng, + int cutpoint_grid_size, double sigma_leaf_mu, Eigen::MatrixXd& sigma_leaf_tau, + double alpha_mu, double alpha_tau, double beta_mu, double beta_tau, + int min_samples_leaf_mu, int min_samples_leaf_tau, double nu, double lamb, + double a_leaf_mu, double a_leaf_tau, double b_leaf_mu, double b_leaf_tau, + double sigma2, int num_trees_mu, int num_trees_tau, double b1, double b0, + std::vector& feature_types_mu, std::vector& feature_types_tau, + int num_gfr, int num_burnin, int num_mcmc, double leaf_init_mu, double leaf_init_tau) { + if (univariate_) { + bcf_univariate_.SampleBCF( + forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, + alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau + ); + } else { + bcf_multivariate_.SampleBCF( + forest_samples_mu, forest_samples_tau, rng, cutpoint_grid_size, sigma_leaf_mu, sigma_leaf_tau, + alpha_mu, alpha_tau, beta_mu, beta_tau, min_samples_leaf_mu, min_samples_leaf_tau, nu, lamb, + a_leaf_mu, a_leaf_tau, b_leaf_mu, b_leaf_tau, sigma2, num_trees_mu, num_trees_tau, b1, b0, + feature_types_mu, feature_types_tau, num_gfr, num_burnin, num_mcmc, leaf_init_mu, leaf_init_tau + ); + } + } + + void LoadTrain(double* residual_data_ptr, int num_rows, double* prognostic_covariate_data_ptr, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, + int num_treatment, bool treatment_binary) { + if (univariate_) { + bcf_univariate_.LoadTrain( + residual_data_ptr, num_rows, prognostic_covariate_data_ptr, num_prognostic_covariates, + treatment_covariate_data_ptr, num_treatment_covariates, treatment_data_ptr, num_treatment, treatment_binary + ); + } else { + bcf_multivariate_.LoadTrain( + residual_data_ptr, num_rows, prognostic_covariate_data_ptr, num_prognostic_covariates, + treatment_covariate_data_ptr, num_treatment_covariates, treatment_data_ptr, num_treatment, treatment_binary + ); + } + } + + void LoadTrain(double* residual_data_ptr, int num_rows, double* prognostic_covariate_data_ptr, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, + int num_treatment, bool treatment_binary, double* weights_data_ptr) { + if (univariate_) { + bcf_univariate_.LoadTrain( + residual_data_ptr, num_rows, prognostic_covariate_data_ptr, num_prognostic_covariates, + treatment_covariate_data_ptr, num_treatment_covariates, treatment_data_ptr, num_treatment, treatment_binary, weights_data_ptr + ); + } else { + bcf_multivariate_.LoadTrain( + residual_data_ptr, num_rows, prognostic_covariate_data_ptr, num_prognostic_covariates, + treatment_covariate_data_ptr, num_treatment_covariates, treatment_data_ptr, num_treatment, treatment_binary, weights_data_ptr + ); + } + } + + void LoadTest(double* prognostic_covariate_data_ptr, int num_rows, int num_prognostic_covariates, + double* treatment_covariate_data_ptr, int num_treatment_covariates, double* treatment_data_ptr, int num_treatment) { + if (univariate_) { + bcf_univariate_.LoadTest( + prognostic_covariate_data_ptr, num_rows, num_prognostic_covariates, treatment_covariate_data_ptr, + num_treatment_covariates, treatment_data_ptr, num_treatment + ); + } else { + bcf_multivariate_.LoadTest( + prognostic_covariate_data_ptr, num_rows, num_prognostic_covariates, treatment_covariate_data_ptr, + num_treatment_covariates, treatment_data_ptr, num_treatment + ); + } + } + + void ResetGlobalVarSamples(double* data_ptr, int num_samples) { + if (univariate_) { + bcf_univariate_.ResetGlobalVarSamples(data_ptr, num_samples); + } else { + bcf_multivariate_.ResetGlobalVarSamples(data_ptr, num_samples); + } + } + + void ResetPrognosticLeafVarSamples(double* data_ptr, int num_samples) { + if (univariate_) { + bcf_univariate_.ResetPrognosticLeafVarSamples(data_ptr, num_samples); + } else { + bcf_multivariate_.ResetPrognosticLeafVarSamples(data_ptr, num_samples); + } + } + + void ResetTreatmentLeafVarSamples(double* data_ptr, int num_samples) { + if (univariate_) { + bcf_univariate_.ResetTreatmentLeafVarSamples(data_ptr, num_samples); + } else { + bcf_multivariate_.ResetTreatmentLeafVarSamples(data_ptr, num_samples); + } + } + + void ResetTreatedCodingSamples(double* data_ptr, int num_samples) { + if (univariate_) { + bcf_univariate_.ResetTreatedCodingSamples(data_ptr, num_samples); + } else { + bcf_multivariate_.ResetTreatedCodingSamples(data_ptr, num_samples); + } + } + + void ResetControlCodingSamples(double* data_ptr, int num_samples) { + if (univariate_) { + bcf_univariate_.ResetControlCodingSamples(data_ptr, num_samples); + } else { + bcf_multivariate_.ResetControlCodingSamples(data_ptr, num_samples); + } + } + + void ResetTrainPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples, int treatment_dim) { + if (univariate_) { + bcf_univariate_.ResetTrainPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + } else { + bcf_multivariate_.ResetTrainPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + } + } + + void ResetTestPredictionSamples(double* muhat_data_ptr, double* tauhat_data_ptr, double* yhat_data_ptr, int num_obs, int num_samples, int treatment_dim) { + if (univariate_) { + bcf_univariate_.ResetTestPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + } else { + bcf_multivariate_.ResetTestPredictionSamples(muhat_data_ptr, tauhat_data_ptr, yhat_data_ptr, num_obs, num_samples, treatment_dim); + } + } +}; + } // namespace StochTree #endif // STOCHTREE_CPP_API_H_ From a0a53adac5809a529f0cefc88c04a047e08f8bb3 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 30 Apr 2024 02:55:42 -0500 Subject: [PATCH 8/8] Removed commented lines --- include/stochtree/cpp_api.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/include/stochtree/cpp_api.h b/include/stochtree/cpp_api.h index 2302a988..297859ad 100644 --- a/include/stochtree/cpp_api.h +++ b/include/stochtree/cpp_api.h @@ -818,10 +818,6 @@ class BCFModelWrapper { BCFModel bcf_multivariate_; public: BCFModelWrapper(bool univariate = true){ -// BCFModel bcf_univariate_{}; -// BCFModel bcf_multivariate_{}; -// bcf_univariate_ = BCFModel(); -// bcf_multivariate_ = BCFModel(); if (univariate) univariate_ = true; else univariate_ = false; }