@@ -58,14 +58,25 @@ def splitter(
58
58
train , dev , test = simple_split (data , pct_train , pct_dev , pct_test )
59
59
60
60
# Final assertions for time series
61
- window = tss .get ('window' , 1 ) if tss .get ('window' , 1 ) else 1
62
- horizon = tss .get ('horizon' , 1 ) if tss .get ('horizon' , 1 ) else 1
63
-
64
- if min (len (train ), len (dev )) < window :
65
- raise Exception (f"Dataset size is too small for the specified window size ({ window } )" )
66
-
67
- if min (len (train ), len (dev ), len (test )) < horizon :
68
- raise Exception (f"Dataset size is too small for the specified horizon size ({ horizon } )" )
61
+ if tss .get ('is_timeseries' , False ) not in (None , False ):
62
+ window = tss .get ('window' , 1 ) if tss .get ('window' , 1 ) else 1
63
+ horizon = tss .get ('horizon' , 1 ) if tss .get ('horizon' , 1 ) else 1
64
+
65
+ if all ([pct_train , pct_dev , pct_test ]) > 0.0 :
66
+ check_partitions = [train , dev , test ]
67
+ elif all ([pct_train , pct_test ]) > 0.0 :
68
+ check_partitions = [train , test ]
69
+ elif all ([pct_train , pct_dev ]) > 0.0 :
70
+ check_partitions = [train , dev ]
71
+ else :
72
+ check_partitions = [train ]
73
+ partition_lengths = [len (partition ) for partition in check_partitions ]
74
+
75
+ if min (partition_lengths ) < window :
76
+ raise Exception (f"Dataset too small for the specified window size ({ window } ). Partition length: { partition_lengths } " ) # noqa
77
+
78
+ if min (partition_lengths ) < horizon :
79
+ raise Exception (f"Dataset too small for the specified horizon size ({ horizon } ). Partition length: { partition_lengths } " ) # noqa
69
80
70
81
return {"train" : train , "test" : test , "dev" : dev , "stratified_on" : stratify_on }
71
82
0 commit comments