Skip to content

Commit 80881f8

Browse files
authored
fix(utils): remove redundant column in train_test_split() (ibis-project#131)
1 parent f0263b9 commit 80881f8

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

ibis_ml/utils/_split.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,6 @@ def train_test_split(
9090
< int((1 - test_size) * num_buckets)
9191
)
9292

93-
return table[table.train].drop(["combined_key"]), table[~table.train].drop(
94-
["combined_key"]
93+
return table[table.train].drop(["combined_key", "train"]), table[~table.train].drop(
94+
["combined_key", "train"]
9595
)

tests/test_train_test_split.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@ def test_train_test_split():
1616
# Check counts and overlaps in train and test dataset
1717
assert train_table.count().execute() + test_table.count().execute() == N
1818
assert train_table.intersect(test_table).count().execute() == 0
19+
assert set(train_table.columns) == set(table.columns)
20+
assert set(test_table.columns) == set(table.columns)
1921

2022
# Check reproducibility
2123
reproduced_train_table, reproduced_test_table = ml.train_test_split(

0 commit comments

Comments
 (0)