@@ -368,13 +368,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
368
368
if all (isinstance (i , str ) for i in drop_vars_mu ):
369
369
if not np .all (np .isin (drop_vars_mu , X_train .columns )):
370
370
raise ValueError ("drop_vars_mu includes some variable names that are not in X_train" )
371
- variable_subset_mu = [i for i in X_train .shape [1 ] if drop_vars_mu .count (X_train .columns .array [i ]) == 0 ]
371
+ variable_subset_mu = [i for i in range ( X_train .shape [1 ]) if drop_vars_mu .count (X_train .columns .array [i ]) == 0 ]
372
372
elif all (isinstance (i , int ) for i in drop_vars_mu ):
373
373
if any (i >= X_train .shape [1 ] for i in drop_vars_mu ):
374
374
raise ValueError ("drop_vars_mu includes some variable indices that exceed the number of columns in X_train" )
375
375
if any (i < 0 for i in drop_vars_mu ):
376
376
raise ValueError ("drop_vars_mu includes some negative variable indices" )
377
- variable_subset_mu = [i for i in X_train .shape [1 ] if drop_vars_mu .count (i ) == 0 ]
377
+ variable_subset_mu = [i for i in range ( X_train .shape [1 ]) if drop_vars_mu .count (i ) == 0 ]
378
378
else :
379
379
raise ValueError ("drop_vars_mu must be a list of variable names (str) or column indices (int)" )
380
380
elif isinstance (drop_vars_mu , np .ndarray ):
@@ -399,7 +399,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
399
399
if all (isinstance (i , str ) for i in keep_vars_tau ):
400
400
if not np .all (np .isin (keep_vars_tau , X_train .columns )):
401
401
raise ValueError ("keep_vars_tau includes some variable names that are not in X_train" )
402
- variable_subset_tau = [i for i in X_train .shape [1 ] if keep_vars_tau .count (X_train .columns .array [i ]) > 0 ]
402
+ variable_subset_tau = [i for i in range ( X_train .shape [1 ]) if keep_vars_tau .count (X_train .columns .array [i ]) > 0 ]
403
403
elif all (isinstance (i , int ) for i in keep_vars_tau ):
404
404
if any (i >= X_train .shape [1 ] for i in keep_vars_tau ):
405
405
raise ValueError ("keep_vars_tau includes some variable indices that exceed the number of columns in X_train" )
@@ -412,7 +412,7 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
412
412
if keep_vars_tau .dtype == np .str_ :
413
413
if not np .all (np .isin (keep_vars_tau , X_train .columns )):
414
414
raise ValueError ("keep_vars_tau includes some variable names that are not in X_train" )
415
- variable_subset_tau = [i for i in X_train .shape [1 ] if keep_vars_tau .count (X_train .columns .array [i ]) > 0 ]
415
+ variable_subset_tau = [i for i in range ( X_train .shape [1 ]) if keep_vars_tau .count (X_train .columns .array [i ]) > 0 ]
416
416
else :
417
417
if np .any (keep_vars_tau >= X_train .shape [1 ]):
418
418
raise ValueError ("keep_vars_tau includes some variable indices that exceed the number of columns in X_train" )
@@ -426,13 +426,13 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
426
426
if all (isinstance (i , str ) for i in drop_vars_tau ):
427
427
if not np .all (np .isin (drop_vars_tau , X_train .columns )):
428
428
raise ValueError ("drop_vars_tau includes some variable names that are not in X_train" )
429
- variable_subset_tau = [i for i in X_train .shape [1 ] if drop_vars_tau .count (X_train .columns .array [i ]) == 0 ]
429
+ variable_subset_tau = [i for i in range ( X_train .shape [1 ]) if drop_vars_tau .count (X_train .columns .array [i ]) == 0 ]
430
430
elif all (isinstance (i , int ) for i in drop_vars_tau ):
431
431
if any (i >= X_train .shape [1 ] for i in drop_vars_tau ):
432
432
raise ValueError ("drop_vars_tau includes some variable indices that exceed the number of columns in X_train" )
433
433
if any (i < 0 for i in drop_vars_tau ):
434
434
raise ValueError ("drop_vars_tau includes some negative variable indices" )
435
- variable_subset_tau = [i for i in X_train .shape [1 ] if drop_vars_tau .count (i ) == 0 ]
435
+ variable_subset_tau = [i for i in range ( X_train .shape [1 ]) if drop_vars_tau .count (i ) == 0 ]
436
436
else :
437
437
raise ValueError ("drop_vars_tau must be a list of variable names (str) or column indices (int)" )
438
438
elif isinstance (drop_vars_tau , np .ndarray ):
0 commit comments