@@ -54,6 +54,7 @@ def __init__(
54
54
self .dtype_dict = dtype_dict
55
55
self .ts_analysis = ts_analysis
56
56
self .grouped_by = ['__default' ] if not ts_analysis ['tss' ].group_by else ts_analysis ['tss' ].group_by
57
+ self .group_boundaries = {} # stores last observed timestamp per series
57
58
self .train_args = train_args .get ('trainer_args' , {}) if train_args else {}
58
59
self .train_args ['early_stop_patience_steps' ] = self .train_args .get ('early_stop_patience_steps' , 10 )
59
60
self .conf_level = self .train_args .pop ('conf_level' , [90 ])
@@ -93,7 +94,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
93
94
oby_col = self .ts_analysis ["tss" ].order_by
94
95
gby = self .ts_analysis ["tss" ].group_by if self .ts_analysis ["tss" ].group_by else []
95
96
df = deepcopy (cat_ds .data_frame )
96
- Y_df = self ._make_initial_df (df )
97
+ Y_df = self ._make_initial_df (df , mode = 'train' )
98
+ self .group_boundaries = self ._set_boundary (Y_df , gby )
97
99
if gby :
98
100
n_time = df [gby ].value_counts ().min ()
99
101
else :
@@ -130,9 +132,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
130
132
log .info ('Successfully trained N-HITS forecasting model.' )
131
133
132
134
def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None:
    """
    Incrementally fit the mixer by re-running `fit` on the provided data.

    The hyperparameter search flag is cleared first so the refit reuses the
    existing configuration instead of searching again.

    :param train_data: encoded dataset to train on.
    :param dev_data: encoded dataset used for validation during training.
    :param args: currently unused; reserved for trainer arguments (e.g. n_epochs).
    """  # noqa
    self.hyperparam_search = False
    self.fit(train_data, dev_data)  # TODO: add support for passing args (e.g. n_epochs)
    self.prepared = True
137
138
138
139
def __call__ (self , ds : Union [EncodedDs , ConcatedEncodedDs ],
@@ -183,7 +184,13 @@ def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs],
183
184
ydf ['confidence' ] = level / 100
184
185
return ydf
185
186
186
- def _make_initial_df (self , df ):
187
+ def _make_initial_df (self , df , mode = 'inference' ):
188
+ """
189
+ Prepares a dataframe for the NHITS model according to what neuralforecast expects.
190
+
191
+ If a per-group boundary exists, this method additionally drops out all observations prior to the cutoff.
192
+ """ # noqa
193
+
187
194
oby_col = self .ts_analysis ["tss" ].order_by
188
195
df = df .sort_values (by = f'__mdb_original_{ oby_col } ' )
189
196
df [f'__mdb_parsed_{ oby_col } ' ] = df .index
@@ -198,4 +205,31 @@ def _make_initial_df(self, df):
198
205
else :
199
206
Y_df ['unique_id' ] = '__default'
200
207
201
- return Y_df .reset_index ()
208
+ Y_df = Y_df .reset_index ()
209
+
210
+ # filter if boundary exists
211
+ if mode == 'train' and self .group_boundaries :
212
+ filtered = []
213
+ grouped = Y_df .groupby (by = 'unique_id' )
214
+ for group , sdf in grouped :
215
+ if group in self .group_boundaries :
216
+ sdf = sdf [sdf ['ds' ].gt (self .group_boundaries [group ])]
217
+ if sdf .shape [0 ] > 0 :
218
+ filtered .append (sdf )
219
+ Y_df = pd .concat (filtered )
220
+
221
+ return Y_df
222
+
223
+ @staticmethod
224
+ def _set_boundary (df : pd .DataFrame , gby : list ) -> Dict [str , object ]:
225
+ """
226
+ Finds last observation for every series in a pre-sorted `df` given a `gby` list of columns to group by.
227
+ """
228
+ if not gby :
229
+ group_boundaries = {'__default' : df .iloc [- 1 ]['ds' ]}
230
+ else :
231
+ # could use groupby().transform('max'), but we leverage pre-sorting instead
232
+ grouped_df = df .groupby (by = 'unique_id' , as_index = False ).last ()
233
+ group_boundaries = grouped_df [['unique_id' , 'ds' ]].set_index ('unique_id' ).to_dict ()['ds' ]
234
+
235
+ return group_boundaries
0 commit comments