import spacy
import html
import string
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from IPython.display import display, HTML
from sklearn.feature_extraction.text import CountVectorizer

nlp = spacy.load("en_core_web_sm")


# The tokenizer is a class rather than a function to avoid reloading the
# parser, stop words, and punctuation on every call.
# It uses spaCy's built-in English language model for preprocessing:
# https://github.com/explosion/spaCy/tree/master/spacy/lang/en
class BOWTokenizer:
    """Default tokenizer used by BOWEncoder for parsing and tokenizing."""

    def __init__(
        self,
        parser,
        stop_words=spacy.lang.en.stop_words.STOP_WORDS,
        punctuations=string.punctuation,
    ):
        """Initialize the BOWTokenizer object.

        Arguments:
            parser {spacy.lang.en.English - by default} -- Any parser object
                that supports a parser(sentence) call.

        Keyword Arguments:
            stop_words {iterable over str} -- Set of stop words to be removed.
                (default: {spacy.lang.en.stop_words.STOP_WORDS})
            punctuations {iterable over str} -- Set of punctuation marks to be
                removed. (default: {string.punctuation})
        """
        self.parser = parser
        # List of stop words and punctuation marks to filter out.
        self.stop_words = stop_words
        self.punctuations = punctuations

    def tokenize(self, sentence, keep_ids=False):
        """Return the sentence (or prose) as a parsed list of tokens.

        Arguments:
            sentence {str} -- Single sentence/prose to be tokenized.

        Keyword Arguments:
            keep_ids {bool} -- If True, returned tokens are indexed by their
                original positions in the parsed sentence. If False, the
                returned tokens do not preserve positionality. Must be False
                for training. Set to True at test/execution time, when the
                user needs explainability. (default: {False})

        Returns:
            list -- List of all tokens extracted from the sentence.
        """
        EMPTYTOKEN = "empty_token"
        # Parse the sentence into a spaCy Doc with linguistic annotations.
        mytokens = self.parser(sentence)
        # Lemmatize each token, strip surrounding whitespace, and lowercase.
        mytokens = [
            word.lemma_.lower().strip() if word.lemma_ != "-PRON-" else word.lower_
            for word in mytokens
        ]
        # Remove stop words and punctuation.
        if keep_ids is True:
            # Replace filtered-out tokens with a placeholder so positions are preserved.
            return [
                word
                if word not in self.stop_words and word not in self.punctuations
                else EMPTYTOKEN
                for word in mytokens
            ]
        else:
            return [
                word
                for word in mytokens
                if word not in self.stop_words and word not in self.punctuations
            ]

    def parse(self, sentence):
        return self.parser(sentence)
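

# Illustrative usage of BOWTokenizer (a minimal sketch, not part of the original
# module; the sample sentence and expected outputs are assumptions):
#
#   tokenizer = BOWTokenizer(nlp)
#   tokenizer.tokenize("The cats are sitting on the mat.")
#   # -> lemmatized tokens with stop words and punctuation removed, e.g. ["cat", "sit", "mat"]
#   tokenizer.tokenize("The cats are sitting on the mat.", keep_ids=True)
#   # -> one entry per parsed token, with filtered tokens replaced by "empty_token"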


class BOWEncoder:
    """Default encoder class with a built-in function for decoding text that
    has been encoded by the same object. Also supports label encoding.
    Can be used as a skeleton on which to build more sophisticated encoders.
    """

    def __init__(self):
        """Initialize the encoder object and set the internal tokenizer,
        labelEncoder and vectorizer using predefined objects.
        """
        # The tokenizer must have tokenize() and parse() methods.
        self.tokenizer = BOWTokenizer(nlp)
        self.labelEncoder = LabelEncoder()
        self.vectorizer = CountVectorizer(
            tokenizer=self.tokenizer.tokenize, ngram_range=(1, 1), min_df=1
        )
        self.decode_params = {}

    # The keep_ids flag is used by explain_local in the explainer to decode
    # importances over raw features.
    def encode_features(self, X_str, needs_fit=True, keep_ids=False):
        """Encode the dataset from string form to encoded vector form using
        the tokenizer and vectorizer.

        Arguments:
            X_str {iterable over str} -- The X data in string form.

        Keyword Arguments:
            needs_fit {bool} -- Whether the vectorizer itself needs to be
                trained. (default: {True})
            keep_ids {bool} -- Whether to preserve the position of encoded
                words with respect to the raw document. Must be False for
                training and True for explanations and decoding.
                (default: {False})

        Returns:
            [List with 2 components] --
                * X_vec -- The dataset vectorized and encoded to numeric form.
                * self.vectorizer -- The trained vectorizer.
        """
        # Encoding while preserving ids is used only for importance
        # computation, not during training.
        if keep_ids is True and isinstance(X_str, str):
            X_str = self.tokenizer.tokenize(X_str, keep_ids=True)
        # needs_fit is set to True when the encoder is not already trained.
        if needs_fit is True:
            self.vectorizer.fit(X_str)
        if isinstance(X_str, str):
            X_str = [X_str]
        X_vec = self.vectorizer.transform(X_str)
        return [X_vec, self.vectorizer]

    def encode_labels(self, y_str, needs_fit=True):
        """Use the default label encoder to encode labels into vector form.

        Arguments:
            y_str {iterable over str} -- Array-like with label names as elements.

        Keyword Arguments:
            needs_fit {bool} -- Whether the label encoder needs training.
                (default: {True})

        Returns:
            [List with 2 components] --
                * y_vec -- The labels vectorized and encoded to numeric form.
                * self.labelEncoder -- The trained label encoder object.
        """
        y_str = np.asarray(y_str[:]).reshape(-1, 1)
        if needs_fit is True:
            y_vec = self.labelEncoder.fit_transform(y_str)
        else:
            y_vec = self.labelEncoder.transform(y_str)
        return [y_vec, self.labelEncoder]

    def decode_imp(self, encoded_imp, input_text):
        """Decode importances over encoded features into importances over
        raw features. Assumes the encoding was done with the same object.
        Operates on a datapoint-by-datapoint basis.

        Arguments:
            encoded_imp {list} -- List of importances in the order of the
                encoded features.
            input_text {str} -- Raw text over which importances are to be
                returned.

        Returns:
            [Tuple with 2 components] --
                * decoded_imp -- Importances with a 1:1 mapping to the parsed
                    sentence.
                * parsed_sentence -- Raw text parsed into a list of individual
                    raw features.
        """
        EMPTYTOKEN = "empty_token"
        # Obtain the parsed sentence while preserving the token -> position mapping.
        parsed_sentence = [str(token) for token in self.tokenizer.parse(input_text)]
        encoded_text = self.tokenizer.tokenize(input_text, keep_ids=True)
        # Tokens removed during tokenization were replaced by the empty token
        # and map to no vocabulary index.
        encoded_word_ids = [
            None if word == EMPTYTOKEN else self.vectorizer.vocabulary_.get(word)
            for word in encoded_text
        ]
        # Look up the importance for each encoded word; words with no
        # vocabulary entry receive an importance of 0.
        decoded_imp = [
            0 if idx is None else encoded_imp[idx] for idx in encoded_word_ids
        ]
        return (decoded_imp, parsed_sentence)
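

# Illustrative encode/decode round trip with BOWEncoder (a minimal sketch, not
# part of the original module; the texts, labels, and the importance vector
# coefficients_for_one_class are hypothetical placeholders):
#
#   encoder = BOWEncoder()
#   X_vec, _ = encoder.encode_features(["good movie", "bad plot"], needs_fit=True)
#   y_vec, _ = encoder.encode_labels(["pos", "neg"], needs_fit=True)
#   # After a linear model is trained on X_vec, its per-feature importances can
#   # be mapped back onto one raw document:
#   imp, parsed = encoder.decode_imp(coefficients_for_one_class, "good movie")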


def plot_local_imp(parsed_sentence, word_importances, max_alpha=0.5):
    """Plot the top importances for a parsed sentence when the corresponding
    importances are available.

    Internal fast-prototyping tool for easy visualization; serves as a visual
    proxy for the dashboard.

    Arguments:
        parsed_sentence {list} -- Raw text parsed into a list of individual
            raw features.
        word_importances {list} -- Importances with a 1:1 mapping to the
            parsed sentence.

    Keyword Arguments:
        max_alpha {float} -- Changes the intensity of the coloring returned
            by the tool. (default: {0.5})
    """

    # Escape special characters like & and < so the browser does not render
    # something other than the intended text.
    def html_escape(text):
        return html.escape(text)

    # Normalize importances so that their absolute values sum to 100.
    word_importances = 100.0 * np.asarray(word_importances) / np.sum(np.abs(word_importances))
    highlighted_text = []
    for i, word in enumerate(parsed_sentence):
        weight = word_importances[i]
        if weight > 0:
            # Positive importances are highlighted in blue.
            highlighted_text.append(
                '<span style="background-color:rgba(135,206,250,'
                + str(abs(weight) / max_alpha)
                + ');">'
                + html_escape(word)
                + "</span>"
            )
        elif weight < 0:
            # Negative importances are highlighted in red.
            highlighted_text.append(
                '<span style="background-color:rgba(250,0,0,'
                + str(abs(weight) / max_alpha)
                + ');">'
                + html_escape(word)
                + "</span>"
            )
        else:
            highlighted_text.append(word)
    highlighted_text = " ".join(highlighted_text)
    display(HTML(highlighted_text))


def plot_global_imp(top_words, top_importances, label_name):
    """Plot the top global importances as a matplotlib bar graph.

    Arguments:
        top_words {list} -- Words with a 1:1 mapping to top_importances.
        top_importances {list} -- Top importance values for the top words.
        label_name {str} -- Label for which importances are being displayed.
    """
    plt.figure(figsize=(14, 7))
    plt.title("most important words for class label: " + str(label_name), fontsize=18)
    plt.bar(range(len(top_importances)), top_importances, color="r", align="center")
    plt.xticks(range(len(top_importances)), top_words, rotation=60, fontsize=18)
    plt.show()
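

# A minimal end-to-end sketch of how these utilities fit together. This block is
# illustrative only and not part of the original module; the sample corpus, the
# labels, and the use of sklearn's LogisticRegression are assumptions made for
# demonstration.
if __name__ == "__main__":
    from sklearn.linear_model import LogisticRegression

    corpus = [
        "I loved this movie, the plot was great",
        "Terrible film with a boring plot",
        "A great and memorable performance",
        "Boring, terrible and forgettable",
    ]
    labels = ["pos", "neg", "pos", "neg"]

    # Encode text and labels, then fit a simple linear classifier.
    encoder = BOWEncoder()
    X_vec, _ = encoder.encode_features(corpus, needs_fit=True)
    y_vec, _ = encoder.encode_labels(labels, needs_fit=True)
    clf = LogisticRegression().fit(X_vec, y_vec)

    # Treat the coefficients as per-feature importances, map them back onto one
    # raw document, and visualize them word by word.
    doc = corpus[0]
    decoded_imp, parsed_sentence = encoder.decode_imp(clf.coef_[0], doc)
    plot_local_imp(parsed_sentence, np.array(decoded_imp), max_alpha=0.5)

    # Global view: the largest coefficients for the second class label.
    feature_names = encoder.vectorizer.get_feature_names_out()
    top_idx = np.argsort(clf.coef_[0])[::-1][:10]
    plot_global_imp(
        [feature_names[i] for i in top_idx],
        [clf.coef_[0][i] for i in top_idx],
        encoder.labelEncoder.classes_[1],
    )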