Commit deb65f3

Merge pull request #54 from cdpierse/feature/add-nsteps-param
Feature/add nsteps param. Closes #51.
2 parents 0cd3324 + d864150 commit deb65f3

8 files changed: +139 −6 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
             "test",
         ]
     ),
-    version="0.5.0",
+    version="0.5.1",
     license="Apache-2.0",
     description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.",
     long_description=long_description,

test/test_question_answering_explainer.py

Lines changed: 18 additions & 0 deletions
@@ -180,3 +180,21 @@ def test_question_answering_visualize_save_append_html_file_ending():
     qa_explainer.visualize(html_filename)
     assert os.path.exists(html_filename + ".html")
     os.remove(html_filename + ".html")
+
+
+def test_question_answering_custom_steps():
+    qa_explainer = QuestionAnsweringExplainer(
+        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
+    )
+    explainer_question = "what is his name ?"
+    explainer_text = "his name is Bob"
+    qa_explainer(explainer_question, explainer_text, n_steps=1)
+
+
+def test_question_answering_custom_internal_batch_size():
+    qa_explainer = QuestionAnsweringExplainer(
+        DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER
+    )
+    explainer_question = "what is his name ?"
+    explainer_text = "his name is Bob"
+    qa_explainer(explainer_question, explainer_text, internal_batch_size=1)

test/test_sequence_classification_explainer.py

Lines changed: 16 additions & 1 deletion
@@ -68,7 +68,6 @@ def test_sequence_classification_explainer_init_custom_labels_size_error():
     )


-
 def test_sequence_classification_encode():
     seq_explainer = SequenceClassificationExplainer(
         DISTILBERT_MODEL, DISTILBERT_TOKENIZER
@@ -263,3 +262,19 @@ def test_sequence_classification_viz():
     )
     seq_explainer(explainer_string)
     seq_explainer.visualize()
+
+
+def sequence_classification_custom_steps():
+    explainer_string = "I love you , I like you"
+    seq_explainer = SequenceClassificationExplainer(
+        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
+    )
+    seq_explainer(explainer_string, n_steps=1)
+
+
+def sequence_classification_internal_batch_size():
+    explainer_string = "I love you , I like you"
+    seq_explainer = SequenceClassificationExplainer(
+        DISTILBERT_MODEL, DISTILBERT_TOKENIZER
+    )
+    seq_explainer(explainer_string, internal_batch_size=1)

test/test_zero_shot_explainer.py

Lines changed: 26 additions & 0 deletions
@@ -185,3 +185,29 @@ def test_zero_shot_model_lowercase_entailment():
         DISTILBERT_MNLI_MODEL,
         DISTILBERT_MNLI_TOKENIZER,
     )
+
+
+def test_zero_shot_custom_steps():
+    zero_shot_explainer = ZeroShotClassificationExplainer(
+        DISTILBERT_MNLI_MODEL,
+        DISTILBERT_MNLI_TOKENIZER,
+    )
+
+    zero_shot_explainer(
+        "I have a problem with my iphone that needs to be resolved asap!!",
+        labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
+        n_steps=1,
+    )
+
+
+def test_zero_shot_internal_batch_size():
+    zero_shot_explainer = ZeroShotClassificationExplainer(
+        DISTILBERT_MNLI_MODEL,
+        DISTILBERT_MNLI_TOKENIZER,
+    )
+
+    zero_shot_explainer(
+        "I have a problem with my iphone that needs to be resolved asap!!",
+        labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
+        internal_batch_size=1,
+    )

transformers_interpret/attributions.py

Lines changed: 12 additions & 0 deletions
@@ -29,6 +29,8 @@ def __init__(
         position_ids: torch.Tensor = None,
         ref_token_type_ids: torch.Tensor = None,
         ref_position_ids: torch.Tensor = None,
+        internal_batch_size: int = None,
+        n_steps: int = 50,
     ):
         super().__init__(custom_forward, embeddings, tokens)
         self.input_ids = input_ids
@@ -38,6 +40,8 @@ def __init__(
         self.position_ids = position_ids
         self.ref_token_type_ids = ref_token_type_ids
         self.ref_position_ids = ref_position_ids
+        self.internal_batch_size = internal_batch_size
+        self.n_steps = n_steps

         self.lig = LayerIntegratedGradients(self.custom_forward, self.embeddings)

@@ -51,6 +55,8 @@ def __init__(
                 ),
                 return_convergence_delta=True,
                 additional_forward_args=(self.attention_mask),
+                internal_batch_size=self.internal_batch_size,
+                n_steps=self.n_steps,
             )
         elif self.position_ids is not None:
             self._attributions, self.delta = self.lig.attribute(
@@ -61,6 +67,8 @@ def __init__(
                 ),
                 return_convergence_delta=True,
                 additional_forward_args=(self.attention_mask),
+                internal_batch_size=self.internal_batch_size,
+                n_steps=self.n_steps,
             )
         elif self.token_type_ids is not None:
             self._attributions, self.delta = self.lig.attribute(
@@ -71,13 +79,17 @@ def __init__(
                 ),
                 return_convergence_delta=True,
                 additional_forward_args=(self.attention_mask),
+                internal_batch_size=self.internal_batch_size,
+                n_steps=self.n_steps,
             )

         else:
             self._attributions, self.delta = self.lig.attribute(
                 inputs=self.input_ids,
                 baselines=self.ref_input_ids,
                 return_convergence_delta=True,
+                internal_batch_size=self.internal_batch_size,
+                n_steps=self.n_steps,
             )

     @property
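
Note: both new arguments are forwarded verbatim to Captum's LayerIntegratedGradients.attribute. Below is a minimal standalone sketch of what they control; the toy model, tensor shapes, and chosen values are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients

class ToyClassifier(nn.Module):
    # Tiny stand-in for a transformer: an embedding layer plus a linear head.
    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, 2)

    def forward(self, input_ids):
        # Mean-pool token embeddings, then project to two logits.
        return self.head(self.embeddings(input_ids).mean(dim=1))

model = ToyClassifier()
input_ids = torch.randint(1, 100, (1, 5))    # one "sentence" of 5 token ids
ref_input_ids = torch.zeros_like(input_ids)  # all-zero, pad-like baseline

def forward_func(ids):
    return model(ids)[:, 0]  # attribute the logit of class 0

lig = LayerIntegratedGradients(forward_func, model.embeddings)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    return_convergence_delta=True,
    n_steps=20,              # fewer interpolation steps: faster but a coarser approximation
    internal_batch_size=4,   # evaluate the interpolated points in chunks of 4 to cap memory
)
print(attributions.shape, delta)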

transformers_interpret/explainers/question_answering.py

Lines changed: 28 additions & 2 deletions
@@ -50,6 +50,9 @@ def __init__(

         self.position = 0

+        self.internal_batch_size = None
+        self.n_steps = 50
+
     def encode(self, text: str) -> list:  # type: ignore
         "Encode 'text' using tokenizer, special tokens are not added"
         return self.tokenizer.encode(text, add_special_tokens=False)
@@ -320,6 +323,8 @@ def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
             ref_position_ids=self.ref_position_ids,
             token_type_ids=self.token_type_ids,
             ref_token_type_ids=self.ref_token_type_ids,
+            internal_batch_size=self.internal_batch_size,
+            n_steps=self.n_steps,
         )
         start_lig.summarize()
         self.start_attributions = start_lig
@@ -337,12 +342,21 @@ def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
             ref_position_ids=self.ref_position_ids,
             token_type_ids=self.token_type_ids,
             ref_token_type_ids=self.ref_token_type_ids,
+            internal_batch_size=self.internal_batch_size,
+            n_steps=self.n_steps,
         )
         end_lig.summarize()
         self.end_attributions = end_lig
         self.attributions = [self.start_attributions, self.end_attributions]

-    def __call__(self, question: str, text: str, embedding_type: int = 2) -> dict:
+    def __call__(
+        self,
+        question: str,
+        text: str,
+        embedding_type: int = 2,
+        internal_batch_size: int = None,
+        n_steps: int = None,
+    ) -> dict:
         """
         Calculates start and end position word attributions for `question` and `text` using the model
         and tokenizer given in the constructor.
@@ -357,9 +371,21 @@ def __call__(self, question: str, text: str, embedding_type: int = 2) -> dict:
             question (str): The question text
             text (str): The text or context from which the model finds an answers
             embedding_type (int, optional): The embedding type word(0), position(1), all(2) to calculate attributions for.
-                Defaults to 2.
+                Defaults to 2.
+            internal_batch_size (int, optional): Divides total #steps * #examples
+                data points into chunks of size at most internal_batch_size,
+                which are computed (forward / backward passes)
+                sequentially. If internal_batch_size is None, then all evaluations are
+                processed in one batch.
+            n_steps (int, optional): The number of steps used by the approximation
+                method. Default: 50.

         Returns:
             dict: Dict for start and end position word attributions.
         """
+
+        if n_steps:
+            self.n_steps = n_steps
+        if internal_batch_size:
+            self.internal_batch_size = internal_batch_size
         return self._run(question, text, embedding_type)
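
With the `__call__` change above, callers can trade attribution fidelity for speed and memory. A hedged usage sketch follows; the checkpoint name is an assumption, and any extractive QA model with a matching tokenizer should behave the same way.

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers_interpret import QuestionAnsweringExplainer

name = "distilbert-base-cased-distilled-squad"  # assumed checkpoint, not from this commit
model = AutoModelForQuestionAnswering.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

qa_explainer = QuestionAnsweringExplainer(model, tokenizer)
word_attributions = qa_explainer(
    "What is his name?",
    "His name is Bob",
    n_steps=25,             # half the default 50: quicker, rougher start/end attributions
    internal_batch_size=2,  # run the extra forward/backward passes in small chunks
)

Since the QA explainer computes attributions twice (start and end positions, as in the diff above), lowering n_steps saves work on both passes.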

transformers_interpret/explainers/sequence_classification.py

Lines changed: 19 additions & 1 deletion
@@ -81,6 +81,9 @@ def __init__(

         self._single_node_output = False

+        self.internal_batch_size = None
+        self.n_steps = 50
+
     @staticmethod
     def _get_id2label_and_label2id_dict(
         labels: List[str],
@@ -239,6 +242,8 @@ def _calculate_attributions(  # type: ignore
             self.attention_mask,
             position_ids=self.position_ids,
             ref_position_ids=self.ref_position_ids,
+            internal_batch_size=self.internal_batch_size,
+            n_steps=self.n_steps,
         )
         lig.summarize()
         self.attributions = lig
@@ -279,6 +284,8 @@ def __call__(
         index: int = None,
         class_name: str = None,
         embedding_type: int = 0,
+        internal_batch_size: int = None,
+        n_steps: int = None,
     ) -> list:
         """
         Calculates attribution for `text` using the model
@@ -299,10 +306,21 @@ def __call__(
             index (int, optional): Optional output index to provide attributions for. Defaults to None.
             class_name (str, optional): Optional output class name to provide attributions for. Defaults to None.
             embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
-
+            internal_batch_size (int, optional): Divides total #steps * #examples
+                data points into chunks of size at most internal_batch_size,
+                which are computed (forward / backward passes)
+                sequentially. If internal_batch_size is None, then all evaluations are
+                processed in one batch.
+            n_steps (int, optional): The number of steps used by the approximation
+                method. Default: 50.
         Returns:
             list: List of tuples containing words and their associated attribution scores.
         """
+
+        if n_steps:
+            self.n_steps = n_steps
+        if internal_batch_size:
+            self.internal_batch_size = internal_batch_size
         return self._run(text, index, class_name, embedding_type=embedding_type)

     def __str__(self):
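
The same pattern applies to SequenceClassificationExplainer. A sketch under the assumption of a sentiment checkpoint (the name below is illustrative):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer

name = "distilbert-base-uncased-finetuned-sst-2-english"  # assumed checkpoint
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

cls_explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer(
    "I love you, I like you",
    n_steps=100,            # above the default 50: slower, but a tighter approximation
    internal_batch_size=8,  # bound peak memory by batching the interpolated inputs
)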

transformers_interpret/explainers/zero_shot_classification.py

Lines changed: 19 additions & 1 deletion
@@ -66,6 +66,9 @@ def __init__(
         self.include_hypothesis = False
         self.attributions = []

+        self.internal_batch_size = None
+        self.n_steps = 50
+
     @property
     def word_attributions(self) -> dict:
         "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
@@ -235,6 +238,8 @@ def _calculate_attributions(  # type: ignore
             ref_position_ids=self.ref_position_ids,
             token_type_ids=self.token_type_ids,
             ref_token_type_ids=self.ref_token_type_ids,
+            internal_batch_size=self.internal_batch_size,
+            n_steps=self.n_steps,
         )
         if self.include_hypothesis:
             lig.summarize()
@@ -249,6 +254,8 @@ def __call__(
         embedding_type: int = 0,
         hypothesis_template="this text is about {} .",
         include_hypothesis: bool = False,
+        internal_batch_size: int = None,
+        n_steps: int = None,
     ) -> dict:
         """
         Calculates attribution for `text` using the model and
@@ -285,10 +292,21 @@ def __call__(
             Defaults to "this text is about {} .".
             include_hypothesis (bool, optional): Alternative option to include hypothesis text in attributions
                 and visualization. Defaults to False.
-
+            internal_batch_size (int, optional): Divides total #steps * #examples
+                data points into chunks of size at most internal_batch_size,
+                which are computed (forward / backward passes)
+                sequentially. If internal_batch_size is None, then all evaluations are
+                processed in one batch.
+            n_steps (int, optional): The number of steps used by the approximation
+                method. Default: 50.
         Returns:
             list: List of tuples containing words and their associated attribution scores.
         """
+
+        if n_steps:
+            self.n_steps = n_steps
+        if internal_batch_size:
+            self.internal_batch_size = internal_batch_size
         self.attributions = []
         self.pred_probs = []
         self.include_hypothesis = include_hypothesis
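
Finally, a sketch for ZeroShotClassificationExplainer, which typically runs an attribution pass per candidate label, so a smaller n_steps cuts total runtime noticeably. The NLI checkpoint name below is an assumption:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import ZeroShotClassificationExplainer

name = "typeform/distilbert-base-uncased-mnli"  # assumed NLI checkpoint
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

zero_shot_explainer = ZeroShotClassificationExplainer(model, tokenizer)
word_attributions = zero_shot_explainer(
    "I have a problem with my iphone that needs to be resolved asap!!",
    labels=["urgent", "not urgent", "phone", "tablet", "computer"],
    n_steps=10,             # cheap per-label attributions for a quick first look
    internal_batch_size=2,  # keep memory flat while sweeping several labels
)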
