Skip to content

Commit 1d8762d

Browse files
authored
Merge pull request #60 from StochasticTree/feature_subset_bcf_hotfix
Updated feature subset code in python BCF interface
2 parents 0f61602 + ce517d0 commit 1d8762d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

stochtree/bcf.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,13 +368,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
368368
if all(isinstance(i, str) for i in drop_vars_mu):
369369
if not np.all(np.isin(drop_vars_mu, X_train.columns)):
370370
raise ValueError("drop_vars_mu includes some variable names that are not in X_train")
371-
variable_subset_mu = [i for i in X_train.shape[1] if drop_vars_mu.count(X_train.columns.array[i]) == 0]
371+
variable_subset_mu = [i for i in range(X_train.shape[1]) if drop_vars_mu.count(X_train.columns.array[i]) == 0]
372372
elif all(isinstance(i, int) for i in drop_vars_mu):
373373
if any(i >= X_train.shape[1] for i in drop_vars_mu):
374374
raise ValueError("drop_vars_mu includes some variable indices that exceed the number of columns in X_train")
375375
if any(i < 0 for i in drop_vars_mu):
376376
raise ValueError("drop_vars_mu includes some negative variable indices")
377-
variable_subset_mu = [i for i in X_train.shape[1] if drop_vars_mu.count(i) == 0]
377+
variable_subset_mu = [i for i in range(X_train.shape[1]) if drop_vars_mu.count(i) == 0]
378378
else:
379379
raise ValueError("drop_vars_mu must be a list of variable names (str) or column indices (int)")
380380
elif isinstance(drop_vars_mu, np.ndarray):
@@ -399,7 +399,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
399399
if all(isinstance(i, str) for i in keep_vars_tau):
400400
if not np.all(np.isin(keep_vars_tau, X_train.columns)):
401401
raise ValueError("keep_vars_tau includes some variable names that are not in X_train")
402-
variable_subset_tau = [i for i in X_train.shape[1] if keep_vars_tau.count(X_train.columns.array[i]) > 0]
402+
variable_subset_tau = [i for i in range(X_train.shape[1]) if keep_vars_tau.count(X_train.columns.array[i]) > 0]
403403
elif all(isinstance(i, int) for i in keep_vars_tau):
404404
if any(i >= X_train.shape[1] for i in keep_vars_tau):
405405
raise ValueError("keep_vars_tau includes some variable indices that exceed the number of columns in X_train")
@@ -412,7 +412,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
412412
if keep_vars_tau.dtype == np.str_:
413413
if not np.all(np.isin(keep_vars_tau, X_train.columns)):
414414
raise ValueError("keep_vars_tau includes some variable names that are not in X_train")
415-
variable_subset_tau = [i for i in X_train.shape[1] if keep_vars_tau.count(X_train.columns.array[i]) > 0]
415+
variable_subset_tau = [i for i in range(X_train.shape[1]) if keep_vars_tau.count(X_train.columns.array[i]) > 0]
416416
else:
417417
if np.any(keep_vars_tau >= X_train.shape[1]):
418418
raise ValueError("keep_vars_tau includes some variable indices that exceed the number of columns in X_train")
@@ -426,13 +426,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
426426
if all(isinstance(i, str) for i in drop_vars_tau):
427427
if not np.all(np.isin(drop_vars_tau, X_train.columns)):
428428
raise ValueError("drop_vars_tau includes some variable names that are not in X_train")
429-
variable_subset_tau = [i for i in X_train.shape[1] if drop_vars_tau.count(X_train.columns.array[i]) == 0]
429+
variable_subset_tau = [i for i in range(X_train.shape[1]) if drop_vars_tau.count(X_train.columns.array[i]) == 0]
430430
elif all(isinstance(i, int) for i in drop_vars_tau):
431431
if any(i >= X_train.shape[1] for i in drop_vars_tau):
432432
raise ValueError("drop_vars_tau includes some variable indices that exceed the number of columns in X_train")
433433
if any(i < 0 for i in drop_vars_tau):
434434
raise ValueError("drop_vars_tau includes some negative variable indices")
435-
variable_subset_tau = [i for i in X_train.shape[1] if drop_vars_tau.count(i) == 0]
435+
variable_subset_tau = [i for i in range(X_train.shape[1]) if drop_vars_tau.count(i) == 0]
436436
else:
437437
raise ValueError("drop_vars_tau must be a list of variable names (str) or column indices (int)")
438438
elif isinstance(drop_vars_tau, np.ndarray):

0 commit comments

Comments
 (0)