 from types import SimpleNamespace
 from typing import Dict

-import torch
 import numpy as np
+from sklearn.utils import shuffle

+from lightwood.helpers.log import log
+from lightwood.data.encoded_ds import EncodedDs
 from lightwood.analysis.base import BaseAnalysisBlock
 from lightwood.helpers.general import evaluate_accuracy
-from lightwood.analysis.nc.util import t_softmax
 from lightwood.api.types import PredictionArguments


-class GlobalFeatureImportance(BaseAnalysisBlock):
+class PermutationFeatureImportance(BaseAnalysisBlock):
     """
-    Analysis block that estimates column importance with a variant of the LOCO (leave-one-covariate-out) algorithm.
+    Analysis block that estimates column importances via permutation.

     Roughly speaking, the procedure:
         - iterates over all input columns
-        - if the input column is optional, then make a predict with its values set to None
-        - compare this accuracy with the accuracy obtained using all data
-        - all accuracy differences are passed through a softmax and reported as estimated column importance scores
+        - if the input column is optional, shuffle its values, then generate predictions for the input data
+        - compare this accuracy with the accuracy obtained using unshuffled data
+        - all accuracy differences are normalized with respect to the original accuracy (clipped at zero if negative)
+        - report these as estimated column importance scores

     Note that, crucially, this method does not refit the predictor at any point.

+    :param row_limit: Set to 0 to use the entire validation dataset.
+    :param col_limit: Set to 0 to consider all possible columns.
+
     Reference:
+        https://scikit-learn.org/stable/modules/permutation_importance.html
         https://compstat-lmu.github.io/iml_methods_limitations/pfi.html
     """
-    def __init__(self, disable_column_importance):
-        super().__init__()
+    def __init__(self, disable_column_importance=False, row_limit=1000, col_limit=10, deps=('AccStats',)):
+        super().__init__(deps=deps)
         self.disable_column_importance = disable_column_importance
+        self.row_limit = row_limit
+        self.col_limit = col_limit
+        self.n_decimals = 3

     def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
         ns = SimpleNamespace(**kwargs)

-        if self.disable_column_importance or ns.tss.is_timeseries or ns.has_pretrained_text_enc:
+        if self.disable_column_importance:
+            info['column_importances'] = None
+        elif ns.tss.is_timeseries or ns.has_pretrained_text_enc:
+            log.warning(f"Block 'PermutationFeatureImportance' does not support time series nor text encoding, skipping...")  # noqa
             info['column_importances'] = None
         else:
-            empty_input_accuracy = {}
-            ignorable_input_cols = [x for x in ns.input_cols if (not ns.tss.is_timeseries or
-                                                                 (x != ns.tss.order_by and
-                                                                  x not in ns.tss.historical_columns))]
-            for col in ignorable_input_cols:
-                partial_data = deepcopy(ns.encoded_val_data)
-                partial_data.clear_cache()
-                partial_data.data_frame[col] = [None] * len(partial_data.data_frame[col])
-
-                args = {'predict_proba': True} if ns.is_classification else {}
-                empty_input_preds = ns.predictor(partial_data, args=PredictionArguments.from_dict(args))
-
-                empty_input_accuracy[col] = np.mean(list(evaluate_accuracy(
-                    ns.data,
-                    empty_input_preds['prediction'],
+            if self.row_limit:
+                log.info(f"[PFI] Using a random sample ({self.row_limit} rows out of {len(ns.encoded_val_data.data_frame)}).")  # noqa
+                ref_df = ns.encoded_val_data.data_frame.sample(frac=1).reset_index(drop=True).iloc[:self.row_limit]
+            else:
+                log.info(f"[PFI] Using complete validation set ({len(ns.encoded_val_data.data_frame)} rows).")
+                ref_df = deepcopy(ns.encoded_val_data.data_frame)
+
+            ref_data = EncodedDs(ns.encoded_val_data.encoders, ref_df, ns.target)
+
+            args = {'predict_proba': True} if ns.is_classification else {}
+            ref_preds = ns.predictor(ref_data, args=PredictionArguments.from_dict(args))
+            ref_score = np.mean(list(evaluate_accuracy(ref_data.data_frame,
+                                                       ref_preds['prediction'],
+                                                       ns.target,
+                                                       ns.accuracy_functions
+                                                       ).values()))
+            shuffled_col_accuracy = {}
+            shuffled_cols = []
+            for x in ns.input_cols:
+                if ('__mdb' not in x) and \
+                        (not ns.tss.is_timeseries or (x != ns.tss.order_by and x not in ns.tss.historical_columns)):
+                    shuffled_cols.append(x)
+
+            if self.col_limit:
+                shuffled_cols = shuffled_cols[:min(self.col_limit, len(ns.encoded_val_data.data_frame.columns))]
+                log.info(f"[PFI] Set to consider first {self.col_limit} columns out of {len(shuffled_cols)}: {shuffled_cols}.")  # noqa
+            else:
+                log.info(f"[PFI] Computing importance for all {len(shuffled_cols)} columns: {shuffled_cols}")
+
+            for col in shuffled_cols:
+                shuffle_data = deepcopy(ref_data)
+                shuffle_data.clear_cache()
+                shuffle_data.data_frame[col] = shuffle(shuffle_data.data_frame[col].values)
+
+                shuffled_preds = ns.predictor(shuffle_data, args=PredictionArguments.from_dict(args))
+                shuffled_col_accuracy[col] = np.mean(list(evaluate_accuracy(
+                    shuffle_data.data_frame,
+                    shuffled_preds['prediction'],
                     ns.target,
                     ns.accuracy_functions
                 ).values()))

             column_importances = {}
-            acc_increases = []
-            for col in ignorable_input_cols:
-                accuracy_increase = (info['normal_accuracy'] - empty_input_accuracy[col])
-                acc_increases.append(accuracy_increase)
-
-            # low 0.2 temperature to accentuate differences
-            acc_increases = t_softmax(torch.Tensor([acc_increases]), t=0.2).tolist()[0]
-            for col, inc in zip(ignorable_input_cols, acc_increases):
-                column_importances[col] = inc  # scores go from 0 to 1
+            acc_increases = np.zeros((len(shuffled_cols),))
+            for i, col in enumerate(shuffled_cols):
+                accuracy_increase = (ref_score - shuffled_col_accuracy[col])
+                acc_increases[i] = round(accuracy_increase, self.n_decimals)
+            for col, inc in zip(shuffled_cols, acc_increases):
+                column_importances[col] = inc

             info['column_importances'] = column_importances

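For intuition, the shuffle-and-score procedure described in the class docstring can be reproduced outside of lightwood. The sketch below is illustrative only: a scikit-learn toy dataset and classifier stand in for the predictor, encoded validation data, and evaluate_accuracy helper that the block actually uses, and the variable names simply mirror the block's bookkeeping.

# Illustrative sketch of permutation feature importance: score once on unshuffled
# data, then shuffle one column at a time and record the accuracy drop.
# The model is never refit, matching the block's key property.
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(random_state=0).fit(X_train, y_train)
ref_score = model.score(X_val, y_val)  # reference accuracy on unshuffled validation data

importances = {}
for col in X_val.columns:
    shuffled = X_val.copy()
    shuffled[col] = shuffle(shuffled[col].values, random_state=0)  # permute a single column
    drop = ref_score - model.score(shuffled, y_val)                # accuracy lost when the column is scrambled
    importances[col] = round(max(drop, 0.0), 3)                    # clip negative drops, as the docstring describes

print(importances)  # per-column accuracy drops; larger means more important

In the block itself, row_limit (default 1000) subsamples the validation rows and col_limit (default 10) caps how many columns are shuffled; setting either to 0 uses the full validation set and all eligible columns.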