Commit 3262859

Merge pull request #1160 from mindsdb/staging
Release 23.6.4.0
2 parents 51a3f03 + 7432c7a commit 3262859

File tree

19 files changed: +261 -1279 lines changed

docssrc/source/tutorials/custom_cleaner/custom_cleaner.ipynb

+27 -723
Large diffs are not rendered by default.

lightwood/__about__.py

+1 -1

@@ -1,6 +1,6 @@
 __title__ = 'lightwood'
 __package_name__ = 'lightwood'
-__version__ = '23.6.2.0'
+__version__ = '23.6.4.0'
 __description__ = "Lightwood is a toolkit for automatic machine learning model building"
 __email__ = "community@mindsdb.com"
 __author__ = 'MindsDB Inc'

lightwood/api/json_ai.py

+11 -2

@@ -741,7 +741,6 @@ def _add_implicit_values(json_ai: JsonAI) -> JsonAI:
                 "dtype_dict": "$dtype_dict",
                 "target": "$target",
                 "mode": "$mode",
-                "ts_analysis": "$ts_analysis",
                 "pred_args": "$pred_args",
             },
         },
@@ -1336,8 +1335,18 @@ def predict(self, data: pd.DataFrame, args: Dict = {{}}) -> pd.DataFrame:
         black = None
 
     if black is not None:
+        try:
+            formatted_predictor_code = black.format_str(predictor_code, mode=black.FileMode())
+
+            if type(predictor_from_code(formatted_predictor_code)).__name__ == 'Predictor':
+                predictor_code = formatted_predictor_code
+            else:
+                log.info('Black formatter output is invalid, predictor code might be a bit ugly')
+
+        except Exception:
+            log.info('Black formatter failed to run, predictor code might be a bit ugly')
+    else:
         log.info('Unable to import black formatter, predictor code might be a bit ugly.')
-        predictor_code = black.format_str(predictor_code, mode=black.FileMode())
 
     return predictor_code

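The hunk above guards the `black` formatting step so that a formatter failure or invalid output no longer breaks predictor code generation. A minimal standalone sketch of the same pattern, assuming a caller-supplied `validate` callable in place of lightwood's own `predictor_from_code` check:

import logging

log = logging.getLogger(__name__)


def format_predictor_code(code: str, validate) -> str:
    """Format `code` with black if available; fall back to the original code on any failure."""
    try:
        import black
    except ImportError:
        log.info('Unable to import black formatter, predictor code might be a bit ugly.')
        return code

    try:
        formatted = black.format_str(code, mode=black.FileMode())
    except Exception:
        log.info('Black formatter failed to run, predictor code might be a bit ugly')
        return code

    # Only keep the formatted version if it still evaluates to a valid predictor.
    if validate(formatted):
        return formatted
    log.info('Black formatter output is invalid, predictor code might be a bit ugly')
    return code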
lightwood/data/timeseries_analyzer.py

+15 -117

@@ -1,17 +1,11 @@
-from copy import deepcopy
-from typing import Dict, Tuple, List, Union
+from typing import Dict, Tuple, List
 
-import optuna
 import numpy as np
 import pandas as pd
-from sktime.transformations.series.detrend import Detrender
-from sktime.forecasting.trend import PolynomialTrendForecaster
-from sktime.transformations.series.detrend import ConditionalDeseasonalizer
+from type_infer.dtype import dtype
 
 from lightwood.api.types import TimeseriesSettings
-from type_infer.dtype import dtype
-from lightwood.helpers.ts import get_ts_groups, get_delta, get_group_matches, Differencer
-from lightwood.helpers.log import log
+from lightwood.helpers.ts import get_ts_groups, get_delta, Differencer
 from lightwood.encoder.time_series.helpers.common import generate_target_group_normalizers
 
 
@@ -36,18 +30,16 @@ def timeseries_analyzer(data: Dict[str, pd.DataFrame], dtype_dict: Dict[str, str
     """  # noqa
     tss = timeseries_settings
     groups = get_ts_groups(data['train'], tss)
-    deltas, periods, freqs = get_delta(data['train'], dtype_dict, groups, target, tss)
+    deltas, periods, freqs = get_delta(data['train'], tss)
 
-    normalizers = generate_target_group_normalizers(data['train'], target, dtype_dict, groups, tss)
+    normalizers = generate_target_group_normalizers(data['train'], target, dtype_dict, tss)
 
     if dtype_dict[target] in (dtype.integer, dtype.float, dtype.num_tsarray):
-        naive_forecast_residuals, scale_factor = get_grouped_naive_residuals(data['dev'], target, tss, groups)
-        differencers = get_differencers(data['train'], target, groups, tss.group_by)
-        stl_transforms = get_stls(data['train'], data['dev'], target, periods, groups, tss)
+        naive_forecast_residuals, scale_factor = get_grouped_naive_residuals(data['dev'], target, tss)
+        differencers = get_differencers(data['train'], target, tss.group_by)
     else:
        naive_forecast_residuals, scale_factor = {}, {}
        differencers = {}
-        stl_transforms = {}
 
     return {'target_normalizers': normalizers,
             'deltas': deltas,
@@ -57,7 +49,7 @@ def timeseries_analyzer(data: Dict[str, pd.DataFrame], dtype_dict: Dict[str, str
             'ts_naive_mae': scale_factor,
             'periods': periods,
             'sample_freqs': freqs,
-            'stl_transforms': stl_transforms,
+            'stl_transforms': {},  # TODO: remove, or provide from outside as user perhaps
             'differencers': differencers
             }
 
@@ -87,121 +79,27 @@ def get_naive_residuals(target_data: pd.DataFrame, m: int = 1) -> Tuple[List, fl
 def get_grouped_naive_residuals(
         info: pd.DataFrame,
         target: str,
-        tss: TimeseriesSettings,
-        group_combinations: List) -> Tuple[Dict, Dict]:
+        tss: TimeseriesSettings
+) -> Tuple[Dict, Dict]:
     """
     Wraps `get_naive_residuals` for a dataframe with multiple co-existing time series.
     """  # noqa
     group_residuals = {}
     group_scale_factors = {}
-    for group in group_combinations:
-        idxs, subset = get_group_matches(info, group, tss.group_by)
+    grouped = info.groupby(by=tss.group_by) if tss.group_by else info.groupby(lambda x: '__default')
+    for group, subset in grouped:
         if subset.shape[0] > 1:
             residuals, scale_factor = get_naive_residuals(subset[target])  # @TODO: pass m once we handle seasonality
             group_residuals[group] = residuals
             group_scale_factors[group] = scale_factor
     return group_residuals, group_scale_factors
 
 
-def get_differencers(data: pd.DataFrame, target: str, groups: List, group_cols: List):
+def get_differencers(data: pd.DataFrame, target: str, group_cols: List):
     differencers = {}
-    for group in groups:
-        idxs, subset = get_group_matches(data, group, group_cols)
+    grouped = data.groupby(by=group_cols) if group_cols else data.groupby(lambda x: True)
+    for group, subset in grouped:
         differencer = Differencer()
         differencer.fit(subset[target].values)
         differencers[group] = differencer
     return differencers
-
-
-def get_stls(train_df: pd.DataFrame,
-             dev_df: pd.DataFrame,
-             target: str,
-             sps: Dict,
-             groups: list,
-             tss: TimeseriesSettings
-             ) -> Dict[str, object]:
-    stls = {'__default': None}
-    for group in groups:
-        if group != '__default':
-            _, tr_subset = get_group_matches(train_df, group, tss.group_by)
-            _, dev_subset = get_group_matches(dev_df, group, tss.group_by)
-            if tr_subset.shape[0] > 0 and dev_subset.shape[0] > 0 and sps.get(group, False):
-                group_freq = tr_subset['__mdb_inferred_freq'].iloc[0]
-                tr_subset = deepcopy(tr_subset)[target]
-                dev_subset = deepcopy(dev_subset)[target]
-                tr_subset.index = pd.date_range(start=tr_subset.iloc[0], freq=group_freq,
-                                                periods=len(tr_subset)).to_period()
-                dev_subset.index = pd.date_range(start=dev_subset.iloc[0], freq=group_freq,
-                                                 periods=len(dev_subset)).to_period()
-                stl = _pick_ST(tr_subset, dev_subset, sps[group])
-                log.info(f'Best STL decomposition params for group {group} are: {stl["best_params"]}')
-                stls[group] = stl
-    return stls
-
-
-def _pick_ST(tr_subset: pd.Series, dev_subset: pd.Series, sp: list):
-    """
-    Perform hyperparam search with optuna to find best combination of ST transforms for a time series.
-
-    :param tr_subset: training series used for fitting blocks. Index should be datetime, and values are the actual time series.
-    :param dev_subset: dev series used for computing loss. Index should be datetime, and values are the actual time series.
-    :param sp: list of candidate seasonal periods
-    :return: best deseasonalizer and detrender combination based on dev_loss
-    """  # noqa
-
-    def _ST_objective(trial: optuna.Trial):
-        trend_degree = trial.suggest_categorical("trend_degree", [1])
-        ds_sp = trial.suggest_categorical("ds_sp", sp)  # seasonality period to use in deseasonalizer
-        if min(min(tr_subset), min(dev_subset)) <= 0:
-            decomp_type = trial.suggest_categorical("decomp_type", ['additive'])
-        else:
-            decomp_type = trial.suggest_categorical("decomp_type", ['additive', 'multiplicative'])
-
-        detrender = Detrender(forecaster=PolynomialTrendForecaster(degree=trend_degree))
-        deseasonalizer = ConditionalDeseasonalizer(sp=ds_sp, model=decomp_type)
-        transformer = STLTransformer(detrender=detrender, deseasonalizer=deseasonalizer, type=decomp_type)
-        transformer.fit(tr_subset)
-        residuals = transformer.transform(dev_subset)
-
-        trial.set_user_attr("transformer", transformer)
-        return np.power(residuals, 2).sum()
-
-    space = {"trend_degree": [1, 2], "ds_sp": sp, "decomp_type": ['additive', 'multiplicative']}
-    study = optuna.create_study(sampler=optuna.samplers.GridSampler(space))
-    study.optimize(_ST_objective, n_trials=8)
-
-    return {
-        "transformer": study.best_trial.user_attrs['transformer'],
-        "best_params": study.best_params
-    }
-
-
-class STLTransformer:
-    def __init__(self, detrender: Detrender, deseasonalizer: ConditionalDeseasonalizer, type: str = 'additive'):
-        """
-        Class that handles STL transformation and inverse, given specific detrender and deseasonalizer instances.
-        :param detrender: Already initialized.
-        :param deseasonalizer: Already initialized.
-        :param type: Either 'additive' or 'multiplicative'.
-        """  # noqa
-        self._type = type
-        self.detrender = detrender
-        self.deseasonalizer = deseasonalizer
-        self.op = {
-            'additive': lambda x, y: x - y,
-            'multiplicative': lambda x, y: x / y
-        }
-        self.iop = {
-            'additive': lambda x, y: x + y,
-            'multiplicative': lambda x, y: x * y
-        }
-
-    def fit(self, x: Union[pd.DataFrame, pd.Series]):
-        self.deseasonalizer.fit(x)
-        self.detrender.fit(self.op[self._type](x, self.deseasonalizer.transform(x)))
-
-    def transform(self, x: Union[pd.DataFrame, pd.Series]):
-        return self.detrender.transform(self.deseasonalizer.transform(x))
-
-    def inverse_transform(self, x: Union[pd.DataFrame, pd.Series]):
-        return self.deseasonalizer.inverse_transform(self.detrender.inverse_transform(x))

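The analyzer refactor above replaces the explicit `get_group_matches` loops with plain pandas `groupby` calls, falling back to a single pseudo-group when no grouping columns are configured. A small sketch of that pattern (the `store`/`sales` columns and the statistic computed are illustrative, not lightwood names):

import pandas as pd


def fit_per_group(df: pd.DataFrame, target: str, group_cols: list) -> dict:
    # Group on the configured columns, or put every row into one '__default' pseudo-group.
    grouped = df.groupby(by=group_cols) if group_cols else df.groupby(lambda _: '__default')

    artifacts = {}
    for group, subset in grouped:
        # One artifact per group key, mirroring get_differencers / get_grouped_naive_residuals.
        artifacts[group] = subset[target].diff().abs().mean()
    return artifacts


df = pd.DataFrame({'store': ['a', 'a', 'b', 'b'], 'sales': [10, 12, 5, 9]})
print(fit_per_group(df, 'sales', ['store']))  # one statistic per store
print(fit_per_group(df, 'sales', []))         # single '__default' group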
lightwood/data/timeseries_transform.py

+30 -76

@@ -5,15 +5,14 @@
 import numpy as np
 import pandas as pd
 from lightwood.helpers.parallelism import get_nr_procs
-from lightwood.helpers.ts import get_ts_groups, get_delta, get_group_matches
 
 from type_infer.dtype import dtype
 from lightwood.api.types import TimeseriesSettings, PredictionArguments
 from lightwood.helpers.log import log
 
 
 def transform_timeseries(
-        data: pd.DataFrame, dtype_dict: Dict[str, str], ts_analysis: dict,
+        data: pd.DataFrame, dtype_dict: Dict[str, str],
         timeseries_settings: TimeseriesSettings, target: str, mode: str,
         pred_args: Optional[PredictionArguments] = None
 ) -> pd.DataFrame:
@@ -29,7 +28,6 @@ def transform_timeseries(
 
     :param data: Dataframe with data to transform.
     :param dtype_dict: Dictionary with the types of each column.
-    :param ts_analysis: dictionary with various insights into each series passed as training input.
     :param timeseries_settings: A `TimeseriesSettings` object.
     :param target: The name of the target column to forecast.
     :param mode: Either "train" or "predict", depending on what phase is calling this procedure.
@@ -43,6 +41,7 @@ def transform_timeseries(
     gb_arr = tss.group_by if tss.group_by is not None else []
     oby = tss.order_by
     window = tss.window
+    oby_col = tss.order_by
 
     if tss.use_previous_target and target not in data.columns:
         raise Exception(f"Cannot transform. Missing historical values for target column {target} (`use_previous_target` is set to True).")  # noqa
@@ -51,37 +50,32 @@ def transform_timeseries(
         if hcol not in data.columns or data[hcol].isna().any():
             raise Exception(f"Cannot transform. Missing values in historical column {hcol}.")
 
-    # infer frequency with get_delta
-    oby_col = tss.order_by
-    groups = get_ts_groups(data, tss)
-
-    # initial stable sort and per-partition deduplication
+    # initial stable sort and per-partition deduplication TODO: slowish, add a top-level param to disable if needed
     data = data.sort_values(by=oby_col, kind='mergesort')
     data = data.drop_duplicates(subset=[oby_col, *gb_arr], keep='first')
 
-    if not ts_analysis:
-        _, periods, freqs = get_delta(data, dtype_dict, groups, target, tss)
-    else:
-        periods = ts_analysis['periods']
-        freqs = ts_analysis['sample_freqs']
-
     # pass seconds to timestamps according to each group's inferred freq, and force this freq on index
-    subsets = []
-    for group in groups:
-        if (tss.group_by and group != '__default') or not tss.group_by:
-            idxs, subset = get_group_matches(data, group, tss.group_by, copy=True)
-            if subset.shape[0] > 0:
-                if periods.get(group, periods['__default']) == 0 and subset.shape[0] > 1:
-                    raise Exception(
-                        f"Partition is not valid, faulty group {group}. Please make sure you group by a set of columns that ensures unique measurements for each grouping through time.")  # noqa
-
-                index = pd.to_datetime(subset[oby_col], unit='s')
-                subset.index = pd.date_range(start=index.iloc[0],
-                                             freq=freqs.get(group, freqs['__default']),
-                                             periods=len(subset))
-                subset['__mdb_inferred_freq'] = subset.index.freq  # sets constant column because pd.concat forgets freq (see: https://github.com/pandas-dev/pandas/issues/3232)  # noqa
-                subsets.append(subset)
-    original_df = pd.concat(subsets).sort_values(by='__mdb_original_index')
+    grouped = data.groupby(by=tss.group_by) if tss.group_by else data.groupby(lambda x: True)
+    reindexed = []
+    # TODO: introduce MP here
+    for name, group in grouped:
+        name = name if tss.group_by and len(tss.group_by) > 1 else (name, )  # guaranteed tuple type
+        if group.shape[0] > 0:
+            if group[tss.order_by].value_counts().max() > 1 and group.shape[0] > 1:
+                raise Exception(f"Partition is not valid, faulty group {name}. Please make sure you group by a set of columns that ensures unique measurements for each grouping through time.")  # noqa
+
+            index = pd.to_datetime(group[oby_col], unit='s', utc=True)
+            group.index = pd.date_range(start=index.iloc[0], end=index.iloc[-1], periods=len(group))
+            resampled = group
+            group['__mdb_inferred_freq'] = None
+            if len(group) > 2:
+                freq = pd.infer_freq(group.index)
+                if freq is not None:
+                    group['__mdb_inferred_freq'] = freq  # sets constant column because pd.concat forgets freq (see: https://github.com/pandas-dev/pandas/issues/3232)  # noqa
+                    resampled = group.resample(freq).first()
+            reindexed.append(resampled)
+
+    original_df = pd.concat(reindexed).sort_values(by='__mdb_original_index')
 
     if '__mdb_forecast_offset' in original_df.columns:
         """ This special column can be either None or an integer. If this column is passed, then the TS transformation will react to the values within:
@@ -103,18 +97,12 @@ def transform_timeseries(
         offset = 0
         cutoff_mode = False
 
-    original_index_list = []
-    idx = 0
-    for row in original_df.itertuples():
-        if _make_pred(row) or cutoff_mode:
-            original_df.at[row.Index, '__make_predictions'] = True
-            original_index_list.append(idx)
-            idx += 1
-        else:
-            original_df.at[row.Index, '__make_predictions'] = False
-            original_index_list.append(None)
-
-    original_df['original_index'] = original_index_list
+    if '__mdb_forecast_offset' in original_df.columns or cutoff_mode:
+        original_df['__make_predictions'] = True
+        original_df['original_index'] = np.arange(len(original_df))
+    else:
+        original_df['__make_predictions'] = False
+        original_df['original_index'] = None
 
     secondary_type_dict = {}
     if dtype_dict[oby] in (dtype.date, dtype.integer, dtype.float):
@@ -191,39 +179,12 @@ def transform_timeseries(
     else:
         raise Exception(f'Not enough historical context to make a timeseries prediction (`allow_incomplete_history` is set to False). Please provide a number of rows greater or equal to the window size - currently (number_rows, window_size) = ({min(group_lengths)}, {tss.window}). If you can\'t get enough rows, consider lowering your window size. If you want to force timeseries predictions lacking historical context please set the `allow_incomplete_history` timeseries setting to `True`, but this might lead to subpar predictions depending on the mixer.')  # noqa
 
-    df_gb_map = None
     if n_groups > 1:
         df_gb_list = list(combined_df.groupby(tss.group_by))
         df_gb_map = {}
         for gb, df in df_gb_list:
             df_gb_map['_' + '_'.join(str(gb))] = df
 
-    timeseries_row_mapping = {}
-    idx = 0
-
-    if df_gb_map is None:
-        for i in range(len(combined_df)):
-            row = combined_df.iloc[i]
-            if not cutoff_mode:
-                timeseries_row_mapping[idx] = int(
-                    row['original_index']) if row['original_index'] is not None and not np.isnan(
-                    row['original_index']) else None
-            else:
-                timeseries_row_mapping[idx] = idx
-            idx += 1
-    else:
-        for gb in df_gb_map:
-            for i in range(len(df_gb_map[gb])):
-                row = df_gb_map[gb].iloc[i]
-                if not cutoff_mode:
-                    timeseries_row_mapping[idx] = int(
-                        row['original_index']) if row['original_index'] is not None and not np.isnan(
-                        row['original_index']) else None
-                else:
-                    timeseries_row_mapping[idx] = idx
-
-                idx += 1
-
     del combined_df['original_index']
 
     return combined_df
@@ -256,13 +217,6 @@ def _ts_infer_next_row(df: pd.DataFrame, ob: str) -> pd.DataFrame:
     return new_df
 
 
-def _make_pred(row) -> bool:
-    """
-    Indicates whether a prediction should be made for `row` or not.
-    """
-    return not hasattr(row, '__mdb_forecast_offset') or row.make_predictions
-
-
 def _ts_to_obj(df: pd.DataFrame, historical_columns: list) -> pd.DataFrame:
     """
     Casts all historical columns in a dataframe to `object` type.

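With `ts_analysis` removed from the transform step, each group's sampling frequency is now inferred on the fly with `pandas.infer_freq` and the rows are resampled onto that grid. A self-contained sketch of that reindex-and-resample step (the epoch-seconds column name `ts` is illustrative):

import pandas as pd


def reindex_group(group: pd.DataFrame, oby_col: str) -> pd.DataFrame:
    # Build a datetime index from epoch seconds, evenly spaced between first and last timestamp.
    index = pd.to_datetime(group[oby_col], unit='s', utc=True)
    group = group.copy()
    group.index = pd.date_range(start=index.iloc[0], end=index.iloc[-1], periods=len(group))

    group['__mdb_inferred_freq'] = None
    if len(group) > 2:
        freq = pd.infer_freq(group.index)
        if freq is not None:
            # Store the frequency in a column (pd.concat drops index freq) and snap rows onto it.
            group['__mdb_inferred_freq'] = freq
            group = group.resample(freq).first()
    return group


df = pd.DataFrame({'ts': [0, 86400, 172800, 259200], 'y': [1.0, 2.0, 3.0, 4.0]})
print(reindex_group(df, 'ts'))  # daily ('D') frequency inferred and applied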