Skip to content

Commit eb67078

Browse files
authored
Merge pull request #79 from cdpierse/feature/multilabel-classification-explainer
MultiLabel Classification Explainer
2 parents 2f22c58 + 8e66166 commit eb67078

24 files changed

+570
-297
lines changed

.DS_Store

10 KB
Binary file not shown.

.pre-commit-config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v2.3.0
4+
hooks:
5+
- id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/psf/black
9+
rev: 21.12b0
10+
hooks:
11+
- id: black
12+
args: [--line-length=120, --target-version=py38]

README.md

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ Check out the streamlit [demo app here](https://share.streamlit.io/cdpierse/tran
3939
- [Sequence Classification Explainer](#sequence-classification-explainer)
4040
- [Visualize Classification attributions](#visualize-classification-attributions)
4141
- [Explaining Attributions for Non Predicted Class](#explaining-attributions-for-non-predicted-class)
42+
- [MultiLabel Classification Explainer](#multilabel-classification-explainer)
43+
- [Visualize MultiLabel Classification attributions](#visualize-multilabel-classification-attributions)
4244
- [Zero Shot Classification Explainer](#zero-shot-classification-explainer)
4345
- [Visualize Zero Shot Classification attributions](#visualize-zero-shot-classification-attributions)
4446
- [Question Answering Explainer (Experimental)](#question-answering-explainer-experimental)
@@ -173,6 +175,241 @@ Getting attributions for different classes is particularly insightful for multic
173175
For a detailed explanation of this example please checkout this [multiclass classification notebook.](notebooks/multiclass_classification_example.ipynb)
174176

175177

178+
</details>
179+
180+
### MultiLabel Classification Explainer
181+
182+
<details><summary>Click to expand</summary>
183+
184+
This explainer is an extension of the `SequenceClassificationExplainer` and is thus compatible with all sequence classification models from the Transformers package. The key change in this explainer is that it calculates attributions for each label in the model's config and returns a dictionary of word attributions w.r.t. each label. The `visualize()` method also displays a table of attributions with attributions calculated per label.
185+
186+
```python
187+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
188+
from transformers_interpret import MultiLabelClassificationExplainer
189+
190+
model_name = "j-hartmann/emotion-english-distilroberta-base"
191+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
192+
tokenizer = AutoTokenizer.from_pretrained(model_name)
193+
194+
195+
cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)
196+
197+
198+
word_attributions = cls_explainer("There were many aspects of the film I liked, but it was frightening and gross in parts. My parents hated it.")
199+
```
200+
This produces a dictionary of word attributions mapping labels to a list of tuples for each word and its attribution score.
201+
<details><summary>Click to see word attribution dictionary</summary>
202+
203+
```python
204+
>>> word_attributions
205+
{'anger': [('<s>', 0.0),
206+
('There', 0.09002208622000409),
207+
('were', -0.025129709879675187),
208+
('many', -0.028852677974079328),
209+
('aspects', -0.06341968013631565),
210+
('of', -0.03587626320752477),
211+
('the', -0.014813095892961287),
212+
('film', -0.14087587475098232),
213+
('I', 0.007367876912617766),
214+
('liked', -0.09816592066307557),
215+
(',', -0.014259517291745674),
216+
('but', -0.08087144668471376),
217+
('it', -0.10185214349220136),
218+
('was', -0.07132244710777856),
219+
('frightening', -0.4125361737439814),
220+
('and', -0.021761663818889918),
221+
('gross', -0.10423745223600908),
222+
('in', -0.02383646952201854),
223+
('parts', -0.027137622525091033),
224+
('.', -0.02960415694062459),
225+
('My', 0.05642774605113695),
226+
('parents', 0.11146648216326158),
227+
('hated', 0.8497975489280364),
228+
('it', 0.05358116678115284),
229+
('.', -0.013566277162080632),
230+
('', 0.09293256725788422),
231+
('</s>', 0.0)],
232+
'disgust': [('<s>', 0.0),
233+
('There', -0.035296263203072),
234+
('were', -0.010224922196739717),
235+
('many', -0.03747571761725605),
236+
('aspects', 0.007696321643436715),
237+
('of', 0.0026740873113235107),
238+
('the', 0.0025752851265661335),
239+
('film', -0.040890035285783645),
240+
('I', -0.014710007408208579),
241+
('liked', 0.025696806663391577),
242+
(',', -0.00739107098314569),
243+
('but', 0.007353791868893654),
244+
('it', -0.00821368234753605),
245+
('was', 0.005439709067819798),
246+
('frightening', -0.8135974168445725),
247+
('and', -0.002334953123414774),
248+
('gross', 0.2366024374426269),
249+
('in', 0.04314772995234148),
250+
('parts', 0.05590472194035334),
251+
('.', -0.04362554293972562),
252+
('My', -0.04252694977895808),
253+
('parents', 0.051580790911406944),
254+
('hated', 0.5067406070057585),
255+
('it', 0.0527491071885104),
256+
('.', -0.008280280618652273),
257+
('', 0.07412384603053103),
258+
('</s>', 0.0)],
259+
'fear': [('<s>', 0.0),
260+
('There', -0.019615758046045408),
261+
('were', 0.008033402634196246),
262+
('many', 0.027772367717635423),
263+
('aspects', 0.01334130725685673),
264+
('of', 0.009186049991879768),
265+
('the', 0.005828877177384549),
266+
('film', 0.09882910753644959),
267+
('I', 0.01753565003544039),
268+
('liked', 0.02062597344466885),
269+
(',', -0.004469530636560965),
270+
('but', -0.019660439408176984),
271+
('it', 0.0488084071292538),
272+
('was', 0.03830859527501167),
273+
('frightening', 0.9526443954511705),
274+
('and', 0.02535156284103706),
275+
('gross', -0.10635301961551227),
276+
('in', -0.019190425328209065),
277+
('parts', -0.01713006453323631),
278+
('.', 0.015043169035757302),
279+
('My', 0.017068079071414916),
280+
('parents', -0.0630781275517486),
281+
('hated', -0.23630028921273583),
282+
('it', -0.056057044429020306),
283+
('.', 0.0015102052077844612),
284+
('', -0.010045048665404609),
285+
('</s>', 0.0)],
286+
'joy': [('<s>', 0.0),
287+
('There', 0.04881772670614576),
288+
('were', -0.0379316152427468),
289+
('many', -0.007955371089444285),
290+
('aspects', 0.04437296429416574),
291+
('of', -0.06407011137335743),
292+
('the', -0.07331568926973099),
293+
('film', 0.21588462483311055),
294+
('I', 0.04885724513463952),
295+
('liked', 0.5309510543276107),
296+
(',', 0.1339765195225006),
297+
('but', 0.09394079060730279),
298+
('it', -0.1462792330432028),
299+
('was', -0.1358591558323458),
300+
('frightening', -0.22184169339341142),
301+
('and', -0.07504142930419291),
302+
('gross', -0.005472075984252812),
303+
('in', -0.0942152657437379),
304+
('parts', -0.19345218754215965),
305+
('.', 0.11096247277185402),
306+
('My', 0.06604512262645984),
307+
('parents', 0.026376541098236207),
308+
('hated', -0.4988319510231699),
309+
('it', -0.17532499366236615),
310+
('.', -0.022609976138939034),
311+
('', -0.43417114685294833),
312+
('</s>', 0.0)],
313+
'neutral': [('<s>', 0.0),
314+
('There', 0.045984598036642205),
315+
('were', 0.017142566357474697),
316+
('many', 0.011419348619472542),
317+
('aspects', 0.02558593440287365),
318+
('of', 0.0186162232003498),
319+
('the', 0.015616416841815963),
320+
('film', -0.021190511300570092),
321+
('I', -0.03572427925026324),
322+
('liked', 0.027062554960050455),
323+
(',', 0.02089914209290366),
324+
('but', 0.025872618597570115),
325+
('it', -0.002980407262316265),
326+
('was', -0.022218157611174086),
327+
('frightening', -0.2982516449116045),
328+
('and', -0.01604643529040792),
329+
('gross', -0.04573829263548096),
330+
('in', -0.006511536166676108),
331+
('parts', -0.011744224307968652),
332+
('.', -0.01817041167875332),
333+
('My', -0.07362312722231429),
334+
('parents', -0.06910711601816408),
335+
('hated', -0.9418903509267312),
336+
('it', 0.022201795222373488),
337+
('.', 0.025694319747309045),
338+
('', 0.04276690822325994),
339+
('</s>', 0.0)],
340+
'sadness': [('<s>', 0.0),
341+
('There', 0.028237893283377526),
342+
('were', -0.04489910545229568),
343+
('many', 0.004996044977269471),
344+
('aspects', -0.1231292680125582),
345+
('of', -0.04552690725956671),
346+
('the', -0.022077819961347042),
347+
('film', -0.14155752357877663),
348+
('I', 0.04135347872193571),
349+
('liked', -0.3097732540526099),
350+
(',', 0.045114660009053134),
351+
('but', 0.0963352125332619),
352+
('it', -0.08120617610094617),
353+
('was', -0.08516150809170213),
354+
('frightening', -0.10386889639962761),
355+
('and', -0.03931986389970189),
356+
('gross', -0.2145059013625132),
357+
('in', -0.03465423285571697),
358+
('parts', -0.08676627134611635),
359+
('.', 0.19025217371906333),
360+
('My', 0.2582092561303794),
361+
('parents', 0.15432351476960307),
362+
('hated', 0.7262186310977987),
363+
('it', -0.029160655114499095),
364+
('.', -0.002758524253450406),
365+
('', -0.33846410359182094),
366+
('</s>', 0.0)],
367+
'surprise': [('<s>', 0.0),
368+
('There', 0.07196110795254315),
369+
('were', 0.1434314520711312),
370+
('many', 0.08812238369489701),
371+
('aspects', 0.013432396769890982),
372+
('of', -0.07127508805657243),
373+
('the', -0.14079766624810955),
374+
('film', -0.16881201614906485),
375+
('I', 0.040595668935112135),
376+
('liked', 0.03239855530171577),
377+
(',', -0.17676382558158257),
378+
('but', -0.03797939330341559),
379+
('it', -0.029191325089641736),
380+
('was', 0.01758013584108571),
381+
('frightening', -0.221738963726823),
382+
('and', -0.05126920277135527),
383+
('gross', -0.33986913466614044),
384+
('in', -0.018180366628697),
385+
('parts', 0.02939418603252064),
386+
('.', 0.018080129971003226),
387+
('My', -0.08060162218059498),
388+
('parents', 0.04351719139081836),
389+
('hated', -0.6919028585285265),
390+
('it', 0.0009574844165327357),
391+
('.', -0.059473118237873344),
392+
('', -0.465690452620123),
393+
('</s>', 0.0)]}
394+
```
395+
</details>
396+
397+
398+
#### Visualize MultiLabel Classification attributions
399+
400+
Sometimes the numeric attributions can be difficult to read, particularly in instances where there is a lot of text. To help with that we also provide the `visualize()` method that utilizes Captum's in-built viz library to create an HTML file highlighting the attributions. For this explainer, attributions will be shown w.r.t. each label.
401+
402+
If you are in a notebook, calls to the `visualize()` method will display the visualization in-line. Alternatively you can pass a filepath in as an argument and an HTML file will be created, allowing you to view the explanation HTML in your browser.
403+
404+
```python
405+
cls_explainer.visualize("multilabel_viz.html")
406+
```
407+
408+
<a href="https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png">
409+
<img src="https://github.com/cdpierse/transformers-interpret/blob/master/images/multilabel_example.png" width="80%" height="80%" align="center"/>
410+
</a>
411+
412+
176413
</details>
177414

178415
### Zero Shot Classification Explainer

images/.DS_Store

6 KB
Binary file not shown.

images/multilabel_example.png

295 KB
Loading

notebooks/.DS_Store

6 KB
Binary file not shown.

notebooks/multiclass_classification_example.ipynb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131
"outputs": [],
3232
"source": [
3333
"tokenizer = AutoTokenizer.from_pretrained(\"sampathkethineedi/industry-classification\")\n",
34-
"model = AutoModelForSequenceClassification.from_pretrained(\n",
35-
" \"sampathkethineedi/industry-classification\"\n",
36-
")"
34+
"model = AutoModelForSequenceClassification.from_pretrained(\"sampathkethineedi/industry-classification\")"
3735
]
3836
},
3937
{

requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pytest==5.4.2
2-
captum==0.3.1
3-
transformers==4.3.2
4-
ipython==7.31.1
2+
captum==0.4.1
3+
transformers==4.15.0
4+
ipython==7.31.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"test",
2323
]
2424
),
25-
version="0.5.2",
25+
version="0.6.0",
2626
license="Apache-2.0",
2727
description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.",
2828
long_description=long_description,

test/.DS_Store

6 KB
Binary file not shown.

test/test_explainer.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ def test_explainer_init_cuda():
123123

124124
def test_explainer_make_input_reference_pair():
125125
explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
126-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
127-
"this is a test string"
128-
)
126+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
129127
assert isinstance(input_ids, Tensor)
130128
assert isinstance(ref_input_ids, Tensor)
131129
assert isinstance(len_inputs, int)
@@ -139,9 +137,7 @@ def test_explainer_make_input_reference_pair():
139137

140138
def test_explainer_make_input_reference_pair_gpt2():
141139
explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER)
142-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
143-
"this is a test string"
144-
)
140+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
145141
assert isinstance(input_ids, Tensor)
146142
assert isinstance(ref_input_ids, Tensor)
147143
assert isinstance(len_inputs, int)
@@ -151,9 +147,7 @@ def test_explainer_make_input_reference_pair_gpt2():
151147

152148
def test_explainer_make_input_token_type_pair_no_sep_idx():
153149
explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
154-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
155-
"this is a test string"
156-
)
150+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
157151
(
158152
token_type_ids,
159153
ref_token_type_ids,
@@ -169,9 +163,7 @@ def test_explainer_make_input_token_type_pair_no_sep_idx():
169163

170164
def test_explainer_make_input_token_type_pair_sep_idx():
171165
explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
172-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
173-
"this is a test string"
174-
)
166+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
175167
(
176168
token_type_ids,
177169
ref_token_type_ids,
@@ -187,12 +179,8 @@ def test_explainer_make_input_token_type_pair_sep_idx():
187179

188180
def test_explainer_make_input_reference_position_id_pair():
189181
explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
190-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
191-
"this is a test string"
192-
)
193-
position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(
194-
input_ids
195-
)
182+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
183+
position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(input_ids)
196184

197185
assert ref_position_ids[0][0] == torch.zeros(len(input_ids[0]))[0]
198186
for i, val in enumerate(position_ids[0]):
@@ -201,9 +189,7 @@ def test_explainer_make_input_reference_position_id_pair():
201189

202190
def test_explainer_make_attention_mask():
203191
explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
204-
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair(
205-
"this is a test string"
206-
)
192+
input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string")
207193
attention_mask = explainer._make_attention_mask(input_ids)
208194
assert len(attention_mask[0]) == len(input_ids[0])
209195
for i, val in enumerate(attention_mask[0]):

0 commit comments

Comments
 (0)