Skip to content

Commit 0dde958

Browse files
authored
Merge pull request #1171 from mindsdb/efficient_nhits_finetuning
feat: efficient nhits finetune
2 parents e22628a + 6cec6a7 commit 0dde958

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

lightwood/mixer/nhits.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
self.dtype_dict = dtype_dict
5555
self.ts_analysis = ts_analysis
5656
self.grouped_by = ['__default'] if not ts_analysis['tss'].group_by else ts_analysis['tss'].group_by
57+
self.group_boundaries = {} # stores last observed timestamp per series
5758
self.train_args = train_args.get('trainer_args', {}) if train_args else {}
5859
self.train_args['early_stop_patience_steps'] = self.train_args.get('early_stop_patience_steps', 10)
5960
self.conf_level = self.train_args.pop('conf_level', [90])
@@ -93,7 +94,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
9394
oby_col = self.ts_analysis["tss"].order_by
9495
gby = self.ts_analysis["tss"].group_by if self.ts_analysis["tss"].group_by else []
9596
df = deepcopy(cat_ds.data_frame)
96-
Y_df = self._make_initial_df(df)
97+
Y_df = self._make_initial_df(df, mode='train')
98+
self.group_boundaries = self._set_boundary(Y_df, gby)
9799
if gby:
98100
n_time = df[gby].value_counts().min()
99101
else:
@@ -130,9 +132,8 @@ def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
130132
log.info('Successfully trained N-HITS forecasting model.')
131133

132134
def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None:
133-
# TODO: reimplement this with automatic novel-row differential
134135
self.hyperparam_search = False
135-
self.fit(dev_data, train_data) # TODO: add support for passing args (e.g. n_epochs)
136+
self.fit(train_data, dev_data) # TODO: add support for passing args (e.g. n_epochs)
136137
self.prepared = True
137138

138139
def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs],
@@ -183,7 +184,13 @@ def __call__(self, ds: Union[EncodedDs, ConcatedEncodedDs],
183184
ydf['confidence'] = level / 100
184185
return ydf
185186

186-
def _make_initial_df(self, df):
187+
def _make_initial_df(self, df, mode='inference'):
188+
"""
189+
Prepares a dataframe for the NHITS model according to what neuralforecast expects.
190+
191+
If a per-group boundary exists, this method additionally drops out all observations prior to the cutoff.
192+
""" # noqa
193+
187194
oby_col = self.ts_analysis["tss"].order_by
188195
df = df.sort_values(by=f'__mdb_original_{oby_col}')
189196
df[f'__mdb_parsed_{oby_col}'] = df.index
@@ -198,4 +205,31 @@ def _make_initial_df(self, df):
198205
else:
199206
Y_df['unique_id'] = '__default'
200207

201-
return Y_df.reset_index()
208+
Y_df = Y_df.reset_index()
209+
210+
# filter if boundary exists
211+
if mode == 'train' and self.group_boundaries:
212+
filtered = []
213+
grouped = Y_df.groupby(by='unique_id')
214+
for group, sdf in grouped:
215+
if group in self.group_boundaries:
216+
sdf = sdf[sdf['ds'].gt(self.group_boundaries[group])]
217+
if sdf.shape[0] > 0:
218+
filtered.append(sdf)
219+
Y_df = pd.concat(filtered)
220+
221+
return Y_df
222+
223+
@staticmethod
224+
def _set_boundary(df: pd.DataFrame, gby: list) -> Dict[str, object]:
225+
"""
226+
Finds last observation for every series in a pre-sorted `df` given a `gby` list of columns to group by.
227+
"""
228+
if not gby:
229+
group_boundaries = {'__default': df.iloc[-1]['ds']}
230+
else:
231+
# could use groupby().transform('max'), but we leverage pre-sorting instead
232+
grouped_df = df.groupby(by='unique_id', as_index=False).last()
233+
group_boundaries = grouped_df[['unique_id', 'ds']].set_index('unique_id').to_dict()['ds']
234+
235+
return group_boundaries

0 commit comments

Comments
 (0)