@@ -76,20 +76,28 @@ def train_test_split(
76
76
random .seed (random_seed )
77
77
78
78
# Generate a random 256-bit key
79
- random_key = str (random .getrandbits (256 ))
79
+ random_str = str (random .getrandbits (256 ))
80
80
81
81
if isinstance (unique_key , str ):
82
82
unique_key = [unique_key ]
83
83
84
+ # Append random string to the name to avoid collision
85
+ combined_key = f"combined_key_{ random_str } "
86
+ train_flag = f"train_{ random_str } "
87
+
84
88
table = table .mutate (
85
- combined_key = ibis .literal ("," ).join (
86
- table [col ].cast ("str" ) for col in unique_key
87
- )
89
+ ** {
90
+ combined_key : ibis .literal ("," ).join (
91
+ table [col ].cast ("str" ) for col in unique_key
92
+ )
93
+ }
88
94
).mutate (
89
- train = (_ .combined_key + random_key ).hash ().abs () % num_buckets
90
- < int ((1 - test_size ) * num_buckets )
95
+ ** {
96
+ train_flag : (_ [combined_key ] + random_str ).hash ().abs () % num_buckets
97
+ < int ((1 - test_size ) * num_buckets )
98
+ }
91
99
)
92
100
93
- return table [table . train ] .drop ([" combined_key" , "train" ]), table [~ table . train ]. drop (
94
- [ "combined_key" , "train" ]
95
- )
101
+ return table [table [ train_flag ]] .drop ([combined_key , train_flag ]), table [
102
+ ~ table [ train_flag ]
103
+ ]. drop ([ combined_key , train_flag ] )
0 commit comments