@@ -39,6 +39,8 @@ Check out the streamlit [demo app here](https://share.streamlit.io/cdpierse/tran
39
39
- [ Sequence Classification Explainer] ( #sequence-classification-explainer )
40
40
- [ Visualize Classification attributions] ( #visualize-classification-attributions )
41
41
- [ Explaining Attributions for Non Predicted Class] ( #explaining-attributions-for-non-predicted-class )
42
+ - [ MultiLabel Classification Explainer] ( #sequence-classification-explainer )
43
+ - [ Visualize MultiLabel Classification attributions] ( #visualize-multilabel-attributions )
42
44
- [ Zero Shot Classification Explainer] ( #zero-shot-classification-explainer )
43
45
- [ Visualize Zero Shot Classification attributions] ( #visualize-zero-shot-classification-attributions )
44
46
- [ Question Answering Explainer (Experimental)] ( #question-answering-explainer-experimental )
@@ -173,6 +175,241 @@ Getting attributions for different classes is particularly insightful for multic
173
175
For a detailed explanation of this example please checkout this [ multiclass classification notebook.] ( notebooks/multiclass_classification_example.ipynb )
174
176
175
177
178
+ </details >
179
+
180
+ ### MultiLabel Classification Explainer
181
+
182
+ <details ><summary >Click to expand</summary >
183
+
184
+ This explainer is an extension of the ` SequenceClassificationExplainer ` and is thus compatible with all sequence classification models from the Transformers package. The key change in this explainer is that it caclulates attributions for each label in the model's config and returns a dictionary of word attributions w.r.t to each label. The ` visualize() ` method also displays a table of attributions with attributions calculated per label.
185
+
186
+ ``` python
187
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
188
+ from transformers_interpret import MultiLabelClassificationExplainer
189
+
190
+ model_name = " j-hartmann/emotion-english-distilroberta-base"
191
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
192
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
193
+
194
+
195
+ cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)
196
+
197
+
198
+ word_attributions = cls_explainer(" There were many aspects of the film I liked, but it was frightening and gross in parts. My parents hated it." )
199
+ ```
200
+ This produces a dictionary of word attributions mapping labels to a list of tuples for each word and it's attribution score.
201
+ <details ><summary >Click to see word attribution dictionary</summary >
202
+
203
+ ``` python
204
+ >> > word_attributions
205
+ {' anger' : [(' <s>' , 0.0 ),
206
+ (' There' , 0.09002208622000409 ),
207
+ (' were' , - 0.025129709879675187 ),
208
+ (' many' , - 0.028852677974079328 ),
209
+ (' aspects' , - 0.06341968013631565 ),
210
+ (' of' , - 0.03587626320752477 ),
211
+ (' the' , - 0.014813095892961287 ),
212
+ (' film' , - 0.14087587475098232 ),
213
+ (' I' , 0.007367876912617766 ),
214
+ (' liked' , - 0.09816592066307557 ),
215
+ (' ,' , - 0.014259517291745674 ),
216
+ (' but' , - 0.08087144668471376 ),
217
+ (' it' , - 0.10185214349220136 ),
218
+ (' was' , - 0.07132244710777856 ),
219
+ (' frightening' , - 0.4125361737439814 ),
220
+ (' and' , - 0.021761663818889918 ),
221
+ (' gross' , - 0.10423745223600908 ),
222
+ (' in' , - 0.02383646952201854 ),
223
+ (' parts' , - 0.027137622525091033 ),
224
+ (' .' , - 0.02960415694062459 ),
225
+ (' My' , 0.05642774605113695 ),
226
+ (' parents' , 0.11146648216326158 ),
227
+ (' hated' , 0.8497975489280364 ),
228
+ (' it' , 0.05358116678115284 ),
229
+ (' .' , - 0.013566277162080632 ),
230
+ (' ' , 0.09293256725788422 ),
231
+ (' </s>' , 0.0 )],
232
+ ' disgust' : [(' <s>' , 0.0 ),
233
+ (' There' , - 0.035296263203072 ),
234
+ (' were' , - 0.010224922196739717 ),
235
+ (' many' , - 0.03747571761725605 ),
236
+ (' aspects' , 0.007696321643436715 ),
237
+ (' of' , 0.0026740873113235107 ),
238
+ (' the' , 0.0025752851265661335 ),
239
+ (' film' , - 0.040890035285783645 ),
240
+ (' I' , - 0.014710007408208579 ),
241
+ (' liked' , 0.025696806663391577 ),
242
+ (' ,' , - 0.00739107098314569 ),
243
+ (' but' , 0.007353791868893654 ),
244
+ (' it' , - 0.00821368234753605 ),
245
+ (' was' , 0.005439709067819798 ),
246
+ (' frightening' , - 0.8135974168445725 ),
247
+ (' and' , - 0.002334953123414774 ),
248
+ (' gross' , 0.2366024374426269 ),
249
+ (' in' , 0.04314772995234148 ),
250
+ (' parts' , 0.05590472194035334 ),
251
+ (' .' , - 0.04362554293972562 ),
252
+ (' My' , - 0.04252694977895808 ),
253
+ (' parents' , 0.051580790911406944 ),
254
+ (' hated' , 0.5067406070057585 ),
255
+ (' it' , 0.0527491071885104 ),
256
+ (' .' , - 0.008280280618652273 ),
257
+ (' ' , 0.07412384603053103 ),
258
+ (' </s>' , 0.0 )],
259
+ ' fear' : [(' <s>' , 0.0 ),
260
+ (' There' , - 0.019615758046045408 ),
261
+ (' were' , 0.008033402634196246 ),
262
+ (' many' , 0.027772367717635423 ),
263
+ (' aspects' , 0.01334130725685673 ),
264
+ (' of' , 0.009186049991879768 ),
265
+ (' the' , 0.005828877177384549 ),
266
+ (' film' , 0.09882910753644959 ),
267
+ (' I' , 0.01753565003544039 ),
268
+ (' liked' , 0.02062597344466885 ),
269
+ (' ,' , - 0.004469530636560965 ),
270
+ (' but' , - 0.019660439408176984 ),
271
+ (' it' , 0.0488084071292538 ),
272
+ (' was' , 0.03830859527501167 ),
273
+ (' frightening' , 0.9526443954511705 ),
274
+ (' and' , 0.02535156284103706 ),
275
+ (' gross' , - 0.10635301961551227 ),
276
+ (' in' , - 0.019190425328209065 ),
277
+ (' parts' , - 0.01713006453323631 ),
278
+ (' .' , 0.015043169035757302 ),
279
+ (' My' , 0.017068079071414916 ),
280
+ (' parents' , - 0.0630781275517486 ),
281
+ (' hated' , - 0.23630028921273583 ),
282
+ (' it' , - 0.056057044429020306 ),
283
+ (' .' , 0.0015102052077844612 ),
284
+ (' ' , - 0.010045048665404609 ),
285
+ (' </s>' , 0.0 )],
286
+ ' joy' : [(' <s>' , 0.0 ),
287
+ (' There' , 0.04881772670614576 ),
288
+ (' were' , - 0.0379316152427468 ),
289
+ (' many' , - 0.007955371089444285 ),
290
+ (' aspects' , 0.04437296429416574 ),
291
+ (' of' , - 0.06407011137335743 ),
292
+ (' the' , - 0.07331568926973099 ),
293
+ (' film' , 0.21588462483311055 ),
294
+ (' I' , 0.04885724513463952 ),
295
+ (' liked' , 0.5309510543276107 ),
296
+ (' ,' , 0.1339765195225006 ),
297
+ (' but' , 0.09394079060730279 ),
298
+ (' it' , - 0.1462792330432028 ),
299
+ (' was' , - 0.1358591558323458 ),
300
+ (' frightening' , - 0.22184169339341142 ),
301
+ (' and' , - 0.07504142930419291 ),
302
+ (' gross' , - 0.005472075984252812 ),
303
+ (' in' , - 0.0942152657437379 ),
304
+ (' parts' , - 0.19345218754215965 ),
305
+ (' .' , 0.11096247277185402 ),
306
+ (' My' , 0.06604512262645984 ),
307
+ (' parents' , 0.026376541098236207 ),
308
+ (' hated' , - 0.4988319510231699 ),
309
+ (' it' , - 0.17532499366236615 ),
310
+ (' .' , - 0.022609976138939034 ),
311
+ (' ' , - 0.43417114685294833 ),
312
+ (' </s>' , 0.0 )],
313
+ ' neutral' : [(' <s>' , 0.0 ),
314
+ (' There' , 0.045984598036642205 ),
315
+ (' were' , 0.017142566357474697 ),
316
+ (' many' , 0.011419348619472542 ),
317
+ (' aspects' , 0.02558593440287365 ),
318
+ (' of' , 0.0186162232003498 ),
319
+ (' the' , 0.015616416841815963 ),
320
+ (' film' , - 0.021190511300570092 ),
321
+ (' I' , - 0.03572427925026324 ),
322
+ (' liked' , 0.027062554960050455 ),
323
+ (' ,' , 0.02089914209290366 ),
324
+ (' but' , 0.025872618597570115 ),
325
+ (' it' , - 0.002980407262316265 ),
326
+ (' was' , - 0.022218157611174086 ),
327
+ (' frightening' , - 0.2982516449116045 ),
328
+ (' and' , - 0.01604643529040792 ),
329
+ (' gross' , - 0.04573829263548096 ),
330
+ (' in' , - 0.006511536166676108 ),
331
+ (' parts' , - 0.011744224307968652 ),
332
+ (' .' , - 0.01817041167875332 ),
333
+ (' My' , - 0.07362312722231429 ),
334
+ (' parents' , - 0.06910711601816408 ),
335
+ (' hated' , - 0.9418903509267312 ),
336
+ (' it' , 0.022201795222373488 ),
337
+ (' .' , 0.025694319747309045 ),
338
+ (' ' , 0.04276690822325994 ),
339
+ (' </s>' , 0.0 )],
340
+ ' sadness' : [(' <s>' , 0.0 ),
341
+ (' There' , 0.028237893283377526 ),
342
+ (' were' , - 0.04489910545229568 ),
343
+ (' many' , 0.004996044977269471 ),
344
+ (' aspects' , - 0.1231292680125582 ),
345
+ (' of' , - 0.04552690725956671 ),
346
+ (' the' , - 0.022077819961347042 ),
347
+ (' film' , - 0.14155752357877663 ),
348
+ (' I' , 0.04135347872193571 ),
349
+ (' liked' , - 0.3097732540526099 ),
350
+ (' ,' , 0.045114660009053134 ),
351
+ (' but' , 0.0963352125332619 ),
352
+ (' it' , - 0.08120617610094617 ),
353
+ (' was' , - 0.08516150809170213 ),
354
+ (' frightening' , - 0.10386889639962761 ),
355
+ (' and' , - 0.03931986389970189 ),
356
+ (' gross' , - 0.2145059013625132 ),
357
+ (' in' , - 0.03465423285571697 ),
358
+ (' parts' , - 0.08676627134611635 ),
359
+ (' .' , 0.19025217371906333 ),
360
+ (' My' , 0.2582092561303794 ),
361
+ (' parents' , 0.15432351476960307 ),
362
+ (' hated' , 0.7262186310977987 ),
363
+ (' it' , - 0.029160655114499095 ),
364
+ (' .' , - 0.002758524253450406 ),
365
+ (' ' , - 0.33846410359182094 ),
366
+ (' </s>' , 0.0 )],
367
+ ' surprise' : [(' <s>' , 0.0 ),
368
+ (' There' , 0.07196110795254315 ),
369
+ (' were' , 0.1434314520711312 ),
370
+ (' many' , 0.08812238369489701 ),
371
+ (' aspects' , 0.013432396769890982 ),
372
+ (' of' , - 0.07127508805657243 ),
373
+ (' the' , - 0.14079766624810955 ),
374
+ (' film' , - 0.16881201614906485 ),
375
+ (' I' , 0.040595668935112135 ),
376
+ (' liked' , 0.03239855530171577 ),
377
+ (' ,' , - 0.17676382558158257 ),
378
+ (' but' , - 0.03797939330341559 ),
379
+ (' it' , - 0.029191325089641736 ),
380
+ (' was' , 0.01758013584108571 ),
381
+ (' frightening' , - 0.221738963726823 ),
382
+ (' and' , - 0.05126920277135527 ),
383
+ (' gross' , - 0.33986913466614044 ),
384
+ (' in' , - 0.018180366628697 ),
385
+ (' parts' , 0.02939418603252064 ),
386
+ (' .' , 0.018080129971003226 ),
387
+ (' My' , - 0.08060162218059498 ),
388
+ (' parents' , 0.04351719139081836 ),
389
+ (' hated' , - 0.6919028585285265 ),
390
+ (' it' , 0.0009574844165327357 ),
391
+ (' .' , - 0.059473118237873344 ),
392
+ (' ' , - 0.465690452620123 ),
393
+ (' </s>' , 0.0 )]}
394
+ ```
395
+ </details >
396
+
397
+
398
+ #### Visualize MultiLabel Classification attributions
399
+
400
+ Sometimes the numeric attributions can be difficult to read particularly in instances where there is a lot of text. To help with that we also provide the ` visualize() ` method that utilizes Captum's in built viz library to create a HTML file highlighting the attributions. For this explainer attributions will be show w.r.t to each label.
401
+
402
+ If you are in a notebook, calls to the ` visualize() ` method will display the visualization in-line. Alternatively you can pass a filepath in as an argument and an HTML file will be created, allowing you to view the explanation HTML in your browser.
403
+
404
+ ``` python
405
+ cls_explainer.visualize(" multilabel_viz.html" )
406
+ ```
407
+
408
+ <a href =" https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png " >
409
+ <img src =" https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png " width =" 80% " height =" 80% " align =" center " />
410
+ </a >
411
+
412
+
176
413
</details >
177
414
178
415
### Zero Shot Classification Explainer
0 commit comments