Skip to content

Commit b3376bf

Browse files
authored
Merge pull request #161 from StochasticTree/numpy-dtype-hotfix
Specify float dtype in numpy ones and zeros
2 parents 7abd34d + 5db6888 commit b3376bf

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

stochtree/bart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -734,13 +734,13 @@ def sample(
734734
)
735735
if self.has_basis:
736736
if sigma_leaf is None:
737-
current_leaf_scale = np.zeros((self.num_basis, self.num_basis))
737+
current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float)
738738
np.fill_diagonal(
739739
current_leaf_scale,
740740
np.squeeze(np.var(resid_train)) / num_trees_mean,
741741
)
742742
elif isinstance(sigma_leaf, float):
743-
current_leaf_scale = np.zeros((self.num_basis, self.num_basis))
743+
current_leaf_scale = np.zeros((self.num_basis, self.num_basis), dtype=float)
744744
np.fill_diagonal(current_leaf_scale, sigma_leaf)
745745
elif isinstance(sigma_leaf, np.ndarray):
746746
if sigma_leaf.ndim != 2:
@@ -834,7 +834,7 @@ def sample(
834834
alpha_init = np.array([1])
835835
elif num_rfx_components > 1:
836836
alpha_init = np.concatenate(
837-
(np.ones(1), np.zeros(num_rfx_components - 1))
837+
(np.ones(1, dtype=float), np.zeros(num_rfx_components - 1, dtype=float))
838838
)
839839
else:
840840
raise ValueError("There must be at least 1 random effect component")

stochtree/bcf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ def sample(
11631163
raise ValueError("sigma_leaf_mu must be a scalar")
11641164
if isinstance(sigma_leaf_tau, float):
11651165
if Z_train.shape[1] > 1:
1166-
current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]))
1166+
current_leaf_scale_tau = np.zeros((Z_train.shape[1], Z_train.shape[1]), dtype=float)
11671167
np.fill_diagonal(current_leaf_scale_tau, sigma_leaf_tau)
11681168
else:
11691169
current_leaf_scale_tau = np.array([[sigma_leaf_tau]])
@@ -1230,7 +1230,7 @@ def sample(
12301230
alpha_init = np.array([1])
12311231
elif num_rfx_components > 1:
12321232
alpha_init = np.concatenate(
1233-
(np.ones(1), np.zeros(num_rfx_components - 1))
1233+
(np.ones(1, dtype=float), np.zeros(num_rfx_components - 1, dtype=float))
12341234
)
12351235
else:
12361236
raise ValueError("There must be at least 1 random effect component")
@@ -1488,7 +1488,7 @@ def sample(
14881488

14891489
# Initialize the leaves of each tree in the treatment forest
14901490
if self.multivariate_treatment:
1491-
init_tau = np.zeros(Z_train.shape[1])
1491+
init_tau = np.zeros(Z_train.shape[1], dtype=float)
14921492
else:
14931493
init_tau = np.array([0.0])
14941494
forest_sampler_tau.prepare_for_sampler(

stochtree/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
raise ValueError("`leaf_dimension` must be an integer greater than 0")
114114
if leaf_model_scale is None:
115115
diag_value = 1.0 / num_trees
116-
leaf_model_scale_array = np.zeros((leaf_dimension, leaf_dimension), float)
116+
leaf_model_scale_array = np.zeros((leaf_dimension, leaf_dimension), dtype=float)
117117
np.fill_diagonal(leaf_model_scale_array, diag_value)
118118
else:
119119
if isinstance(leaf_model_scale, np.ndarray):
@@ -128,7 +128,7 @@ def __init__(
128128
"`leaf_model_scale` must be positive, if provided as scalar"
129129
)
130130
leaf_model_scale_array = np.zeros(
131-
(leaf_dimension, leaf_dimension), float
131+
(leaf_dimension, leaf_dimension), dtype=float
132132
)
133133
np.fill_diagonal(leaf_model_scale_array, leaf_model_scale)
134134
else:
@@ -278,7 +278,7 @@ def update_leaf_model_scale(
278278
"`leaf_model_scale` must be positive, if provided as scalar"
279279
)
280280
leaf_model_scale_array = np.zeros(
281-
(self.leaf_dimension, self.leaf_dimension), float
281+
(self.leaf_dimension, self.leaf_dimension), dtype=float
282282
)
283283
np.fill_diagonal(leaf_model_scale_array, leaf_model_scale)
284284
else:

0 commit comments

Comments
 (0)