diff --git a/tests/utils.py b/tests/utils.py index 7498553..24af83f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,9 +60,8 @@ def check_folder(output_root, epsilon, e_weight, f_weight, e_filename, f_filenam tests = sum(f_map[f] for f, s in data if s == "test") else: trains, tests = splits[-1].count("train"), splits[-1].count("test") - train_frac, test_frac = trains / (trains + tests), tests / (trains + tests) - assert 0.7 * (1 - epsilon) <= train_frac - assert 0.3 * (1 - epsilon) <= test_frac + assert int(0.7 * (1 - epsilon) * (trains + tests)) <= trains + assert int(0.3 * (1 - epsilon) * (trains + tests)) <= tests if n == "I": break