Commit 31e9925

Merge pull request #962 from mindsdb/staging
Release 22.8.1.0
2 parents: 5cf15bb + 258ad71

26 files changed: +1996 −1403 lines

.flake8
+1 −1

@@ -1,4 +1,4 @@
 [flake8]
 max-line-length = 120
-ignore = E402,F821,W503,W504,C408,W391
+ignore = E275,E402,F821,W503,W504,C408,W391
 exclude = .git,__pycache__,docs,docssrc
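For reference, E275 is pycodestyle's "missing whitespace after keyword" check; adding it to the ignore list silences warnings on lines such as the following (a generic illustration, not code from this repo):

    # E275: missing whitespace after keyword
    enabled = True
    flag = not(enabled)   # would be flagged by E275 if it were not ignored
    flag = not enabled    # the spacing E275 expects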

docssrc/source/tutorials/custom_explainer/custom_explainer.ipynb

+1,238 −1,238
Large diffs are not rendered by default.

lightwood/__about__.py
+1 −1

@@ -1,6 +1,6 @@
 __title__ = 'lightwood'
 __package_name__ = 'lightwood'
-__version__ = '22.7.4.0'
+__version__ = '22.8.1.0'
 __description__ = "Lightwood is a toolkit for automatic machine learning model building"
 __email__ = "community@mindsdb.com"
 __author__ = 'MindsDB Inc'

lightwood/analysis/__init__.py
+2 −2

@@ -8,7 +8,7 @@
 from lightwood.analysis.helpers.acc_stats import AccStats
 from lightwood.analysis.helpers.conf_stats import ConfStats
 from lightwood.analysis.nn_conf.temp_scale import TempScaler
-from lightwood.analysis.helpers.feature_importance import GlobalFeatureImportance
+from lightwood.analysis.helpers.feature_importance import PermutationFeatureImportance
 
 try:
     from lightwood.analysis.helpers.shap import ShapleyValues
@@ -17,4 +17,4 @@
 
 
 __all__ = ['model_analyzer', 'explain', 'BaseAnalysisBlock', 'TempScaler',
-           'ICP', 'AccStats', 'ConfStats', 'GlobalFeatureImportance', 'ShapleyValues']
+           'ICP', 'AccStats', 'ConfStats', 'PermutationFeatureImportance', 'ShapleyValues']

lightwood/analysis/helpers/acc_stats.py
+4 −3

@@ -15,7 +15,8 @@ class AccStats(BaseAnalysisBlock):
     """ Computes accuracy stats and a confusion matrix for the validation dataset """
 
     def __init__(self, deps=('ICP',)):
-        super().__init__(deps=deps)  # @TODO: enforce that this actually prevents early execution somehow
+        super().__init__(deps=deps)
+        self.n_decimals = 3
 
     def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
         ns = SimpleNamespace(**kwargs)
@@ -29,7 +30,7 @@ def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
         info['score_dict'] = evaluate_accuracy(ns.data, ns.normal_predictions['prediction'],
                                                ns.target, accuracy_functions, ts_analysis=ns.ts_analysis)
 
-        info['normal_accuracy'] = np.mean(list(info['score_dict'].values()))
+        info['normal_accuracy'] = round(np.mean(list(info['score_dict'].values())), self.n_decimals)
         self.fit(ns, info['result_df'])
         info['val_overall_acc'], info['acc_histogram'], info['cm'], info['acc_samples'] = self.get_accuracy_stats()
         return info
@@ -99,7 +100,7 @@ def get_accuracy_stats(self, is_classification=None, is_numerical=None):
         for counts in list(bucket_acc_counts.values()):
             accuracy_count += counts
 
-        overall_accuracy = sum(accuracy_count) / len(accuracy_count)
+        overall_accuracy = round(sum(accuracy_count) / len(accuracy_count), self.n_decimals)
 
         for bucket in range(len(self.buckets)):
             if bucket not in bucket_accuracy:
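AccStats (and ConfStats below) now round reported metrics through a shared n_decimals = 3 setting. A quick worked example of the pattern, using hypothetical scores:

    import numpy as np

    n_decimals = 3
    score_dict = {'r2': 0.81234, 'mae_norm': 0.79456}  # hypothetical accuracy scores
    normal_accuracy = round(np.mean(list(score_dict.values())), n_decimals)
    print(normal_accuracy)  # 0.803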

lightwood/analysis/helpers/conf_stats.py
+4 −3

@@ -18,6 +18,7 @@ def __init__(self, deps=('ICP',), ece_bins: int = 10):
         super().__init__(deps=deps)
         self.ece_bins = ece_bins
         self.ordenc = OrdinalEncoder()
+        self.n_decimals = 3
 
     def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
         ns = SimpleNamespace(**kwargs)
@@ -38,10 +39,10 @@ def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
                                                     ns.data,
                                                     ns.target,
                                                     task_type)
-        info['maximum_calibration_error'] = mce
-        info['expected_calibration_error'] = ece
+        info['maximum_calibration_error'] = round(mce, self.n_decimals)
+        info['expected_calibration_error'] = round(ece, self.n_decimals)
         info['binned_conf_acc_difference'] = ces
-        info['global_calibration_score'] = gscore
+        info['global_calibration_score'] = round(gscore, self.n_decimals)
         return info
 
     def _get_stats(self, confs, preds, data, target, task_type='categorical'):

lightwood/analysis/helpers/feature_importance.py
+66 −34

@@ -2,69 +2,101 @@
 from types import SimpleNamespace
 from typing import Dict
 
-import torch
 import numpy as np
+from sklearn.utils import shuffle
 
+from lightwood.helpers.log import log
+from lightwood.data.encoded_ds import EncodedDs
 from lightwood.analysis.base import BaseAnalysisBlock
 from lightwood.helpers.general import evaluate_accuracy
-from lightwood.analysis.nc.util import t_softmax
 from lightwood.api.types import PredictionArguments
 
 
-class GlobalFeatureImportance(BaseAnalysisBlock):
+class PermutationFeatureImportance(BaseAnalysisBlock):
     """
-    Analysis block that estimates column importance with a variant of the LOCO (leave-one-covariate-out) algorithm.
+    Analysis block that estimates column importances via permutation.
 
     Roughly speaking, the procedure:
     - iterates over all input columns
-    - if the input column is optional, then make a predict with its values set to None
-    - compare this accuracy with the accuracy obtained using all data
-    - all accuracy differences are passed through a softmax and reported as estimated column importance scores
+    - if the input column is optional, shuffle its values, then generate predictions for the input data
+    - compare this accuracy with the accuracy obtained using unshuffled data
+    - all accuracy differences are normalized with respect to the original accuracy (clipped at zero if negative)
+    - report these as estimated column importance scores
 
     Note that, crucially, this method does not refit the predictor at any point.
 
+    :param row_limit: Set to 0 to use the entire validation dataset.
+    :param col_limit: Set to 0 to consider all possible columns.
+
     Reference:
+        https://scikit-learn.org/stable/modules/permutation_importance.html
         https://compstat-lmu.github.io/iml_methods_limitations/pfi.html
     """
-    def __init__(self, disable_column_importance):
-        super().__init__()
+    def __init__(self, disable_column_importance=False, row_limit=1000, col_limit=10, deps=tuple('AccStats',)):
+        super().__init__(deps=deps)
         self.disable_column_importance = disable_column_importance
+        self.row_limit = row_limit
+        self.col_limit = col_limit
+        self.n_decimals = 3
 
     def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
         ns = SimpleNamespace(**kwargs)
 
-        if self.disable_column_importance or ns.tss.is_timeseries or ns.has_pretrained_text_enc:
+        if self.disable_column_importance:
+            info['column_importances'] = None
+        elif ns.tss.is_timeseries or ns.has_pretrained_text_enc:
+            log.warning(f"Block 'PermutationFeatureImportance' does not support time series nor text encoding, skipping...")  # noqa
             info['column_importances'] = None
         else:
-            empty_input_accuracy = {}
-            ignorable_input_cols = [x for x in ns.input_cols if (not ns.tss.is_timeseries or
-                                                                 (x != ns.tss.order_by and
-                                                                  x not in ns.tss.historical_columns))]
-            for col in ignorable_input_cols:
-                partial_data = deepcopy(ns.encoded_val_data)
-                partial_data.clear_cache()
-                partial_data.data_frame[col] = [None] * len(partial_data.data_frame[col])
-
-                args = {'predict_proba': True} if ns.is_classification else {}
-                empty_input_preds = ns.predictor(partial_data, args=PredictionArguments.from_dict(args))
-
-                empty_input_accuracy[col] = np.mean(list(evaluate_accuracy(
-                    ns.data,
-                    empty_input_preds['prediction'],
+            if self.row_limit:
+                log.info(f"[PFI] Using a random sample ({self.row_limit} rows out of {len(ns.encoded_val_data.data_frame)}).")  # noqa
+                ref_df = ns.encoded_val_data.data_frame.sample(frac=1).reset_index(drop=True).iloc[:self.row_limit]
+            else:
+                log.info(f"[PFI] Using complete validation set ({len(ns.encoded_val_data.data_frame)} rows).")
+                ref_df = deepcopy(ns.encoded_val_data.data_frame)
+
+            ref_data = EncodedDs(ns.encoded_val_data.encoders, ref_df, ns.target)
+
+            args = {'predict_proba': True} if ns.is_classification else {}
+            ref_preds = ns.predictor(ref_data, args=PredictionArguments.from_dict(args))
+            ref_score = np.mean(list(evaluate_accuracy(ref_data.data_frame,
                                                        ref_preds['prediction'],
                                                        ns.target,
                                                        ns.accuracy_functions
                                                        ).values()))
+            shuffled_col_accuracy = {}
+            shuffled_cols = []
+            for x in ns.input_cols:
+                if ('__mdb' not in x) and \
+                        (not ns.tss.is_timeseries or (x != ns.tss.order_by and x not in ns.tss.historical_columns)):
+                    shuffled_cols.append(x)
+
+            if self.col_limit:
+                shuffled_cols = shuffled_cols[:min(self.col_limit, len(ns.encoded_val_data.data_frame.columns))]
+                log.info(f"[PFI] Set to consider first {self.col_limit} columns out of {len(shuffled_cols)}: {shuffled_cols}.")  # noqa
+            else:
+                log.info(f"[PFI] Computing importance for all {len(shuffled_cols)} columns: {shuffled_cols}")
+
+            for col in shuffled_cols:
+                shuffle_data = deepcopy(ref_data)
+                shuffle_data.clear_cache()
+                shuffle_data.data_frame[col] = shuffle(shuffle_data.data_frame[col].values)
+
+                shuffled_preds = ns.predictor(shuffle_data, args=PredictionArguments.from_dict(args))
+                shuffled_col_accuracy[col] = np.mean(list(evaluate_accuracy(
+                    shuffle_data.data_frame,
+                    shuffled_preds['prediction'],
                     ns.target,
                     ns.accuracy_functions
                 ).values()))
 
             column_importances = {}
-            acc_increases = []
-            for col in ignorable_input_cols:
-                accuracy_increase = (info['normal_accuracy'] - empty_input_accuracy[col])
-                acc_increases.append(accuracy_increase)
-
-            # low 0.2 temperature to accentuate differences
-            acc_increases = t_softmax(torch.Tensor([acc_increases]), t=0.2).tolist()[0]
-            for col, inc in zip(ignorable_input_cols, acc_increases):
-                column_importances[col] = inc  # scores go from 0 to 1
+            acc_increases = np.zeros((len(shuffled_cols),))
+            for i, col in enumerate(shuffled_cols):
+                accuracy_increase = (ref_score - shuffled_col_accuracy[col])
+                acc_increases[i] = round(accuracy_increase, self.n_decimals)
+            for col, inc in zip(shuffled_cols, acc_increases):
+                column_importances[col] = inc
 
             info['column_importances'] = column_importances
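For readers unfamiliar with the technique, here is a minimal, self-contained sketch of permutation feature importance in the same spirit as the new block. It reuses sklearn.utils.shuffle as the block does, but the model, dataset, and metric below are generic scikit-learn stand-ins, not lightwood APIs:

    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import Ridge
    from sklearn.metrics import r2_score
    from sklearn.model_selection import train_test_split
    from sklearn.utils import shuffle

    X, y = load_diabetes(return_X_y=True, as_frame=True)
    X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
    model = Ridge().fit(X_train, y_train)

    # Reference score on untouched validation data
    ref_score = r2_score(y_val, model.predict(X_val))

    importances = {}
    for col in X_val.columns:
        X_shuffled = X_val.copy()
        # Shuffling one column breaks its relationship with the target
        X_shuffled[col] = shuffle(X_shuffled[col].values, random_state=0)
        drop = ref_score - r2_score(y_val, model.predict(X_shuffled))
        importances[col] = round(max(drop, 0.0), 3)  # clip at zero, as the docstring describes

    print(sorted(importances.items(), key=lambda kv: -kv[1]))

Columns whose shuffling hurts the score most are the ones the fitted model leans on; note that, as the docstring stresses, the model is never refit.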

lightwood/analysis/nc/base.py
+29 −0

@@ -26,6 +26,15 @@ def get_problem_type(cls) -> str:
         return 'classification'
 
 
+class TSMixin(object):
+    def __init__(self) -> None:
+        super(TSMixin, self).__init__()
+
+    @classmethod
+    def get_problem_type(cls):
+        return 'time-series'
+
+
 class BaseModelAdapter(BaseEstimator):
     __metaclass__ = abc.ABCMeta
 
@@ -114,6 +123,14 @@ def _underlying_predict(self, x: np.array) -> np.array:
         return self.model.predict(x)
 
 
+class TSAdapter(BaseModelAdapter):
+    def __init__(self, model: object, fit_params: Dict[str, object] = None) -> None:
+        super(TSAdapter, self).__init__(model, fit_params)
+
+    def _underlying_predict(self, x: np.array) -> np.array:
+        return self.model.predict(x)
+
+
 class CachedRegressorAdapter(RegressorAdapter):
     def __init__(self, model, fit_params=None):
         super(CachedRegressorAdapter, self).__init__(model, fit_params)
@@ -148,3 +165,15 @@ def predict(self, x=None):
             return t_softmax(self.prediction_cache, t=0.5)
         else:
             return self.prediction_cache
+
+
+class CachedTSAdapter(TSAdapter):
+    def __init__(self, model, fit_params=None):
+        super(CachedTSAdapter, self).__init__(model, fit_params)
+        self.prediction_cache = None
+
+    def fit(self, x=None, y=None):
+        pass
+
+    def predict(self, x=None):
+        return self.prediction_cache
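The cached adapters share a simple pattern: fit is a deliberate no-op and predict replays whatever was stored in prediction_cache beforehand. A minimal sketch of how the new CachedTSAdapter might be driven (the surrounding usage is an assumption for illustration, not code from this commit):

    import numpy as np
    from lightwood.analysis.nc.base import CachedTSAdapter

    # No underlying model is consulted: predictions are computed elsewhere
    # and stashed in the cache before the nonconformity machinery runs.
    adapter = CachedTSAdapter(model=None)
    adapter.prediction_cache = np.array([1.2, 3.4, 5.6])  # hypothetical forecasts

    adapter.fit()              # no-op by design
    print(adapter.predict())   # [1.2 3.4 5.6], replayed straight from the cache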
