Skip to content

Commit 6582682

Browse files
jitingxu1deepyaman
andauthored
fix(utils): fix possible name collision in train_test_split (ibis-project#142)
Co-authored-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
1 parent a183f44 commit 6582682

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

ibis_ml/utils/_split.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,28 @@ def train_test_split(
7676
random.seed(random_seed)
7777

7878
# Generate a random 256-bit key
79-
random_key = str(random.getrandbits(256))
79+
random_str = str(random.getrandbits(256))
8080

8181
if isinstance(unique_key, str):
8282
unique_key = [unique_key]
8383

84+
# Append random string to the name to avoid collision
85+
combined_key = f"combined_key_{random_str}"
86+
train_flag = f"train_{random_str}"
87+
8488
table = table.mutate(
85-
combined_key=ibis.literal(",").join(
86-
table[col].cast("str") for col in unique_key
87-
)
89+
**{
90+
combined_key: ibis.literal(",").join(
91+
table[col].cast("str") for col in unique_key
92+
)
93+
}
8894
).mutate(
89-
train=(_.combined_key + random_key).hash().abs() % num_buckets
90-
< int((1 - test_size) * num_buckets)
95+
**{
96+
train_flag: (_[combined_key] + random_str).hash().abs() % num_buckets
97+
< int((1 - test_size) * num_buckets)
98+
}
9199
)
92100

93-
return table[table.train].drop(["combined_key", "train"]), table[~table.train].drop(
94-
["combined_key", "train"]
95-
)
101+
return table[table[train_flag]].drop([combined_key, train_flag]), table[
102+
~table[train_flag]
103+
].drop([combined_key, train_flag])

0 commit comments

Comments
 (0)