Skip to content

Commit 76e1118

Browse files
authored
Merge pull request #792 from mindsdb/fix_delta_ungrouped
Fix: temporal delta estimation for ungrouped series
2 parents ee739ec + 8e36a0c commit 76e1118

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

lightwood/data/timeseries_analyzer.py

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def get_delta(df: pd.DataFrame, ts_info: dict, group_combinations: list, order_c
 old  new
  82   82     # get default delta for all data
  83   83     for col in order_cols:
  84   84     series = pd.Series([x[-1] for x in df[col]])
       85  +  series = series.drop_duplicates() # by this point df is ordered so duplicate timestamps are either because of non-handled groups or repeated data that, for mode delta estimation, should be ignored # noqa
  85   86     rolling_diff = series.rolling(window=2).apply(lambda x: x.iloc[1] - x.iloc[0])
  86   87     delta = rolling_diff.value_counts(ascending=False).keys()[0] # pick most popular
  87   88     deltas["__default"][col] = delta

lightwood/data/timeseries_transform.py

+1 -1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def transform_timeseries(
 old  new
 106  106     df_arr.append(df.sort_values(by=ob_arr))
 107  107     group_lengths.append(len(df))
 108  108     else:
 109       -  df_arr = [original_df]
      109  +  df_arr = [original_df.sort_values(by=ob_arr)]
 110  110     group_lengths.append(len(original_df))
 111  111
 112  112     n_groups = len(df_arr)

tests/integration/advanced/test_timeseries.py

+7
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,16 @@ def test_1_time_series_regression(self):
 old  new
 137  137
 138  138     # test inferring mode
 139  139     test_df['__mdb_make_predictions'] = False
      140  +  test_df = test_df.sample(frac=1) # shuffle to test internal ordering logic
 140  141     preds = pred.predict(test_df)
 141  142     self.check_ts_prediction_df(preds, nr_preds, [order_by])
 142  143
      144  +  # Additionally, check timestamps are further into the future than test dates
      145  +  latest_timestamp = pd.to_datetime(test_df[order_by]).max().timestamp()
      146  +  for idx, row in preds.iterrows():
      147  +  for timestamp in row[f'order_{order_by}']:
      148  +  assert timestamp > latest_timestamp
      149  +
 143  150     def test_2_time_series_classification(self):
 144  151     from lightwood.api.high_level import predictor_from_problem
 145  152
0 commit comments

Comments
 (0)