From 96add78a4ffe5b6c436ea2d4cb08ca943418047c Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Wed, 15 May 2024 11:34:27 +0200
Subject: [PATCH] Add possibility to skip special tokens during attribution
 (#275)

* Add possibility to skip special tokens during attribution

* Bump pillow

* Bump idna

---
 inseq/attr/feat/feature_attribution.py | 16 ++++++++++++++--
 inseq/attr/step_functions.py           | 14 ++++++++++++++
 inseq/data/attribution.py              | 10 ++++++++--
 inseq/models/attribution_model.py      | 23 +++++++++++++++++++----
 inseq/models/decoder_only.py           |  2 ++
 inseq/models/encoder_decoder.py        |  3 +++
 inseq/models/huggingface_model.py      |  2 +-
 inseq/utils/contrast_utils.py          |  8 +++++++-
 requirements-dev.txt                   |  4 ++--
 requirements.txt                       |  4 ++--
 10 files changed, 72 insertions(+), 14 deletions(-)

diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py
index 3d832bc3..8dc4cfc8 100644
--- a/inseq/attr/feat/feature_attribution.py
+++ b/inseq/attr/feat/feature_attribution.py
@@ -176,6 +176,7 @@ def prepare_and_attribute(
         attribute_target: bool = False,
         step_scores: list[str] = [],
         include_eos_baseline: bool = False,
+        skip_special_tokens: bool = False,
         attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
         attribution_args: dict[str, Any] = {},
         attributed_fn_args: dict[str, Any] = {},
@@ -206,6 +207,8 @@ def prepare_and_attribute(
                 step scores can be added by using the :meth:`~inseq.register_step_function` function.
             include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
                 attribution. By default the EOS token is not used for attribution. Defaults to False.
+            skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
+                Defaults to False.
             attributed_fn (:obj:`str` or :obj:`Callable[..., SingleScorePerStepTensor]`, `optional`): The identifier or
                 function of model outputs representing what should be attributed (e.g. output probits of model best
                 prediction after softmax). If it is a string, it must be a valid function.
@@ -224,12 +227,14 @@ def prepare_and_attribute(
         inputs = (sources, targets)
         if not self.attribution_model.is_encoder_decoder:
             inputs = targets
-        encoded_sources = self.attribution_model.encode(sources, return_baseline=True)
+        encoded_sources = self.attribution_model.encode(
+            sources, return_baseline=True, add_special_tokens=not skip_special_tokens
+        )
         # We do this here to support separate attr_pos_start for different sentences when batching
         if attr_pos_start is None or attr_pos_start < encoded_sources.input_ids.shape[1]:
             attr_pos_start = encoded_sources.input_ids.shape[1]
         batch = self.attribution_model.formatter.prepare_inputs_for_attribution(
-            self.attribution_model, inputs, include_eos_baseline
+            self.attribution_model, inputs, include_eos_baseline, skip_special_tokens
         )
         # If prepare_and_attribute was called from AttributionModel.attribute,
         # attributed_fn is already a Callable. Keep here to allow for usage independently
@@ -245,6 +250,7 @@ def prepare_and_attribute(
             output_step_attributions=output_step_attributions,
             attribute_target=attribute_target,
             step_scores=step_scores,
+            skip_special_tokens=skip_special_tokens,
             attribution_args=attribution_args,
             attributed_fn_args=attributed_fn_args,
             step_scores_args=step_scores_args,
@@ -310,6 +316,7 @@ def format_contrastive_targets(
         step_scores_args: dict[str, Any],
         attr_pos_start: int,
         attr_pos_end: int,
+        skip_special_tokens: bool = False,
     ) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
         contrast_batch, contrast_targets_alignments = None, None
         contrast_targets = attributed_fn_args.get("contrast_targets", None)
@@ -327,6 +334,7 @@ def format_contrastive_targets(
                 attribution_model=self.attribution_model,
                 inputs=contrast_targets,
                 as_targets=as_targets,
+                skip_special_tokens=skip_special_tokens,
             )
             contrast_batch = DecoderOnlyBatch.from_batch(contrast_batch)
             clean_tgt_tokens = self.attribution_model.clean_tokens(target_tokens, as_targets=as_targets)
@@ -358,6 +366,7 @@ def attribute(
         output_step_attributions: bool = False,
         attribute_target: bool = False,
         step_scores: list[str] = [],
+        skip_special_tokens: bool = False,
         attribution_args: dict[str, Any] = {},
         attributed_fn_args: dict[str, Any] = {},
         step_scores_args: dict[str, Any] = {},
@@ -385,6 +394,8 @@ def attribute(
             step_scores (:obj:`list` of `str`): List of identifiers for step scores that need to be computed during
                 attribution. The available step scores are defined in :obj:`inseq.attr.feat.STEP_SCORES_MAP` and new
                 step scores can be added by using the :meth:`~inseq.register_step_function` function.
+            skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
+                Defaults to False.
             attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution method.
             attributed_fn_args (:obj:`dict`, `optional`): Additional arguments to pass to the attributed function.
             step_scores_args (:obj:`dict`, `optional`): Additional arguments to pass to the step scores function.
@@ -419,6 +430,7 @@ def attribute(
             step_scores_args,
             attr_pos_start,
             attr_pos_end,
+            skip_special_tokens,
         )
         target_tokens_with_ids = self.attribution_model.get_token_with_ids(
             batch,
diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py
index 83aa8d6e..df82fef7 100644
--- a/inseq/attr/step_functions.py
+++ b/inseq/attr/step_functions.py
@@ -132,6 +132,7 @@ def contrast_logits_fn(
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     """Returns the logit of a generation target given contrastive context or target prediction alternative.
     If only ``contrast_targets`` are specified, the logit of the contrastive prediction is computed given same
@@ -144,6 +145,7 @@ def contrast_logits_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     )
     return logit_fn(c_args)

@@ -156,6 +158,7 @@ def contrast_prob_fn(
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     logprob: bool = False,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     """Returns the probability of a generation target given contrastive context or target prediction alternative.
     If only ``contrast_targets`` are specified, the probability of the contrastive prediction is computed given same
@@ -168,6 +171,7 @@ def contrast_prob_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     )
     return probability_fn(c_args, logprob=logprob)

@@ -179,6 +183,7 @@ def pcxmi_fn(
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ) -> SingleScorePerStepTensor:
     """Compute the pointwise conditional cross-mutual information (P-CXMI) of target ids given original and
     contrastive input options. The P-CXMI is defined as the negative log-ratio between the conditional probability of the target
@@ -192,6 +197,7 @@ def pcxmi_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     ).to(original_probs.device)
     return -torch.log2(torch.div(original_probs, contrast_probs))

@@ -206,6 +212,7 @@ def kl_divergence_fn(
     top_p: float = 1.0,
     min_tokens_to_keep: int = 1,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ) -> SingleScorePerStepTensor:
     """Compute the pointwise Kullback-Leibler divergence of target ids given original and contrastive input options.
     The KL divergence is the expectation of the log difference between the probabilities of regular (P) and contrastive
@@ -233,6 +240,7 @@ def kl_divergence_fn(
         contrast_targets_alignments=contrast_targets_alignments,
         return_contrastive_target_ids=False,
         return_contrastive_batch=True,
+        skip_special_tokens=skip_special_tokens,
     )
     c_forward_output = args.attribution_model.get_forward_output(
         contrast_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
@@ -263,6 +271,7 @@ def contrast_prob_diff_fn(
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     logprob: bool = False,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     """Returns the difference between next step probability for a candidate generation target vs. a contrastive
     alternative. Can be used as attribution target to answer the question: "Which features were salient in the
@@ -279,6 +288,7 @@ def contrast_prob_diff_fn(
         contrast_targets_alignments=contrast_targets_alignments,
         logprob=logprob,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     ).to(model_probs.device)
     return model_probs - contrast_probs

@@ -290,6 +300,7 @@ def contrast_logits_diff_fn(
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     """Equivalent to ``contrast_prob_diff_fn`` but for logits.
     The original target function used in `Yin and Neubig (2022) `__
@@ -301,6 +312,7 @@ def contrast_logits_diff_fn(
         contrast_targets=contrast_targets,
         contrast_targets_alignments=contrast_targets_alignments,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     ).to(model_logits.device)
     return model_logits - contrast_logits

@@ -312,6 +324,7 @@ def in_context_pvi_fn(
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     """Returns the in-context pointwise V-usable information as defined by `Lu et al. (2023) `__.
     In-context PVI is a variant of P-CXMI that captures the amount of usable
@@ -330,6 +343,7 @@ def in_context_pvi_fn(
         contrast_targets_alignments=contrast_targets_alignments,
         logprob=True,
         contrast_force_inputs=contrast_force_inputs,
+        skip_special_tokens=skip_special_tokens,
     ).to(orig_logprob.device)
     return -orig_logprob + contrast_logprob
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index 7841cf7c..b83da1df 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -47,6 +47,7 @@ def get_batch_from_inputs(
     inputs: FeatureAttributionInput,
     include_eos_baseline: bool = False,
     as_targets: bool = False,
+    skip_special_tokens: bool = False,
 ) -> Batch:
     if isinstance(inputs, Batch):
         batch = inputs
@@ -57,6 +58,7 @@ def get_batch_from_inputs(
             as_targets=as_targets,
             return_baseline=True,
             include_eos_baseline=include_eos_baseline,
+            add_special_tokens=not skip_special_tokens,
         )
     elif isinstance(inputs, BatchEncoding):
         encodings = inputs
@@ -66,8 +68,12 @@ def get_batch_from_inputs(
             "Inputs must be either a string, a list of strings, a BatchEncoding or a Batch."
         )
     embeddings = BatchEmbedding(
-        input_embeds=attribution_model.embed(encodings.input_ids, as_targets=as_targets),
-        baseline_embeds=attribution_model.embed(encodings.baseline_ids, as_targets=as_targets),
+        input_embeds=attribution_model.embed(
+            encodings.input_ids, as_targets=as_targets, add_special_tokens=not skip_special_tokens
+        ),
+        baseline_embeds=attribution_model.embed(
+            encodings.baseline_ids, as_targets=as_targets, add_special_tokens=not skip_special_tokens
+        ),
     )
     batch = Batch(encodings, embeddings)
     return batch
diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py
index 2f259a4c..42d72b3b 100644
--- a/inseq/models/attribution_model.py
+++ b/inseq/models/attribution_model.py
@@ -69,6 +69,7 @@ def prepare_inputs_for_attribution(
         attribution_model: "AttributionModel",
         inputs: FeatureAttributionInput,
         include_eos_baseline: bool = False,
+        skip_special_tokens: bool = False,
     ) -> Union[DecoderOnlyBatch, EncoderDecoderBatch]:
         raise NotImplementedError()

@@ -316,6 +317,7 @@ def attribute(
         device: Optional[str] = None,
         batch_size: Optional[int] = None,
         generate_from_target_prefix: bool = False,
+        skip_special_tokens: bool = False,
         generation_args: dict[str, Any] = {},
         **kwargs,
     ) -> FeatureAttributionOutput:
@@ -365,6 +367,8 @@ def attribute(
                 target prefixes for the generation process. If False, the ``generated_texts`` will be used as full
                 targets. This option is only available for encoder-decoder models, since the same behavior can be
                 achieved by modifying the input texts for decoder-only models. Default: False.
+            skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when attributing the input
+                texts. Default: False.
            **kwargs: Additional keyword arguments. These can include keyword arguments for the attribution method,
                for the generation process or for the attributed function. Generation arguments can be provided
                explicitly as a dictionary named ``generation_args``.
@@ -389,6 +393,8 @@ def attribute(
             self.device = device
         attribution_method = self.get_attribution_method(method, override_default_attribution)
         attributed_fn = self.get_attributed_fn(attributed_fn)
+        if skip_special_tokens:
+            kwargs["skip_special_tokens"] = True
         attribution_args, attributed_fn_args, step_scores_args = extract_args(
             attribution_method,
             attributed_fn,
@@ -418,9 +424,16 @@ def attribute(
             logger.info(f"Splitting input texts into {n_batches} batches of size {batch_size}.")
         # If constrained decoding is not enabled, output texts are generated from input texts.
         if not has_generated_texts or generate_from_target_prefix:
-            encoded_input = self.encode(input_texts, return_baseline=True, include_eos_baseline=include_eos_baseline)
+            encoded_input = self.encode(
+                input_texts,
+                return_baseline=True,
+                include_eos_baseline=include_eos_baseline,
+                add_special_tokens=not skip_special_tokens,
+            )
             if generate_from_target_prefix:
-                decoder_input = self.encode(generated_texts, as_targets=True)
+                decoder_input = self.encode(
+                    generated_texts, as_targets=True, add_special_tokens=not skip_special_tokens
+                )
                 generation_args["decoder_input_ids"] = decoder_input.input_ids
             generated_texts = self.generate(
                 encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
@@ -467,6 +480,7 @@ def attribute(
             attribute_target=attribute_target,
             step_scores=step_scores,
             include_eos_baseline=include_eos_baseline,
+            skip_special_tokens=skip_special_tokens,
             attributed_fn=attributed_fn,
             attribution_args=attribution_args,
             attributed_fn_args=attributed_fn_args,
@@ -484,11 +498,11 @@ def attribute(
             self.device = original_device
         return attribution_output

-    def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False):
+    def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False, add_special_tokens: bool = True):
         if isinstance(inputs, str) or (
             isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
         ):
-            batch = self.encode(inputs, as_targets)
+            batch = self.encode(inputs, as_targets, add_special_tokens=add_special_tokens)
             inputs = batch.input_ids
         return self.embed_ids(inputs, as_targets=as_targets)
@@ -531,6 +545,7 @@ def encode(
         as_targets: bool = False,
         return_baseline: bool = False,
         include_eos_baseline: bool = False,
+        add_special_tokens: bool = True,
     ) -> BatchEncoding:
         pass

diff --git a/inseq/models/decoder_only.py b/inseq/models/decoder_only.py
index e9f8a25e..48aefce5 100644
--- a/inseq/models/decoder_only.py
+++ b/inseq/models/decoder_only.py
@@ -39,12 +39,14 @@ def prepare_inputs_for_attribution(
         attribution_model: "DecoderOnlyAttributionModel",
         inputs: FeatureAttributionInput,
         include_eos_baseline: bool = False,
+        skip_special_tokens: bool = False,
     ) -> DecoderOnlyBatch:
         batch = get_batch_from_inputs(
             attribution_model,
             inputs=inputs,
             include_eos_baseline=include_eos_baseline,
             as_targets=False,
+            skip_special_tokens=skip_special_tokens,
         )
         return DecoderOnlyBatch.from_batch(batch)

diff --git a/inseq/models/encoder_decoder.py b/inseq/models/encoder_decoder.py
index 39fd7c65..81022be3 100644
--- a/inseq/models/encoder_decoder.py
+++ b/inseq/models/encoder_decoder.py
@@ -38,6 +38,7 @@ def prepare_inputs_for_attribution(
         attribution_model: "EncoderDecoderAttributionModel",
         inputs: tuple[FeatureAttributionInput, FeatureAttributionInput],
         include_eos_baseline: bool = False,
+        skip_special_tokens: bool = False,
     ) -> EncoderDecoderBatch:
         r"""Prepares sources and target to produce an :class:`~inseq.data.EncoderDecoderBatch`.
         There are two stages of preparation:
@@ -67,12 +68,14 @@ def prepare_inputs_for_attribution(
             inputs=sources,
             include_eos_baseline=include_eos_baseline,
             as_targets=False,
+            skip_special_tokens=skip_special_tokens,
         )
         target_batch = get_batch_from_inputs(
             attribution_model,
             inputs=targets,
             include_eos_baseline=include_eos_baseline,
             as_targets=True,
+            skip_special_tokens=skip_special_tokens,
         )
         return EncoderDecoderBatch(source_batch, target_batch)

diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index bb44f21c..eb374fbe 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -228,7 +228,7 @@ def generate(
         if isinstance(inputs, str) or (
             isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
         ):
-            inputs = self.encode(inputs)
+            inputs = self.encode(inputs, add_special_tokens=not skip_special_tokens)
         inputs = inputs.to(self.device)
         generation_out = self.model.generate(
             inputs=inputs.input_ids,
diff --git a/inseq/utils/contrast_utils.py b/inseq/utils/contrast_utils.py
index 5d8c00f4..576fef00 100644
--- a/inseq/utils/contrast_utils.py
+++ b/inseq/utils/contrast_utils.py
@@ -64,6 +64,7 @@ def _get_contrast_inputs(
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     return_contrastive_target_ids: bool = False,
     return_contrastive_batch: bool = False,
+    skip_special_tokens: bool = False,
     **forward_kwargs,
 ) -> ContrastInputs:
     """Utility function to return the output of the model for given contrastive inputs.
@@ -81,6 +82,7 @@ def _get_contrast_inputs(
                 attribution_model=args.attribution_model,
                 inputs=contrast_targets,
                 as_targets=is_enc_dec,
+                skip_special_tokens=skip_special_tokens,
             )
         ).to(args.decoder_input_ids.device)
         curr_prefix_len = args.decoder_input_ids.size(1)
@@ -107,7 +109,9 @@ def _get_contrast_inputs(
                 "Contrastive source inputs can only be used with encoder-decoder models. "
                 "Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models."
             )
-        c_enc_in = args.attribution_model.encode(contrast_sources).to(args.encoder_input_ids.device)
+        c_enc_in = args.attribution_model.encode(contrast_sources, add_special_tokens=not skip_special_tokens).to(
+            args.encoder_input_ids.device
+        )
         if (
             args.encoder_input_ids.shape != c_enc_in.input_ids.shape
             or torch.ne(args.encoder_input_ids, c_enc_in.input_ids).any()
@@ -128,6 +132,7 @@ def _setup_contrast_args(
     contrast_targets: Optional[FeatureAttributionInput] = None,
     contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
     contrast_force_inputs: bool = False,
+    skip_special_tokens: bool = False,
 ):
     c_inputs = _get_contrast_inputs(
         args,
@@ -136,6 +141,7 @@ def _setup_contrast_args(
         contrast_targets_alignments=contrast_targets_alignments,
         return_contrastive_target_ids=True,
         return_contrastive_batch=True,
+        skip_special_tokens=skip_special_tokens,
     )
     if args.is_attributed_fn:
         if contrast_force_inputs:
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2fda0ec1..d086fc10 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -101,7 +101,7 @@ huggingface-hub==0.20.3
     #   transformers
 identify==2.5.34
     # via pre-commit
-idna==3.6
+idna==3.7
     # via
     #   requests
     #   yarl
@@ -199,7 +199,7 @@ pbr==6.0.0
     # via stevedore
 pexpect==4.9.0
     # via ipython
-pillow==10.2.0
+pillow==10.3.0
     # via matplotlib
 platformdirs==4.2.0
     # via
diff --git a/requirements.txt b/requirements.txt
index 93809632..05e5b4c7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,7 +24,7 @@ huggingface-hub==0.20.3
     # via
     #   tokenizers
     #   transformers
-idna==3.6
+idna==3.7
     # via requests
 jaxtyping==0.2.25
 jinja2==3.1.3
@@ -55,7 +55,7 @@ packaging==23.2
     #   huggingface-hub
     #   matplotlib
     #   transformers
-pillow==10.2.0
+pillow==10.3.0
     # via matplotlib
 protobuf==4.25.2
     # via transformers
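
For reference, a minimal usage sketch of the flag introduced by this patch. The checkpoint, attribution method, and input text below are illustrative stand-ins, not taken from the patch:

    import inseq

    # Any HuggingFace checkpoint / attribution method pair works; these are examples.
    model = inseq.load_model("gpt2", "saliency")

    # skip_special_tokens=True makes the internal encode() calls use
    # add_special_tokens=False, so special tokens (e.g. BOS/EOS) are neither
    # attributed nor shown in the output.
    out = model.attribute(
        "Hello world, this is a test sentence.",
        skip_special_tokens=True,
    )
    out.show()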
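The flag is also threaded through the contrastive step functions and contrast utilities, so contrastive sources and targets are encoded consistently with the regular inputs. A sketch under the same caveats (checkpoint and texts are illustrative):

    import inseq

    # Encoder-decoder example; any seq2seq checkpoint would do.
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")

    # contrast_prob_diff compares the probability of the generated target with
    # a contrastive alternative; skip_special_tokens is forwarded to the
    # encoding of contrast_targets as well.
    out = model.attribute(
        "The manager went home.",
        "La gérante est rentrée chez elle.",
        attributed_fn="contrast_prob_diff",
        contrast_targets="Le gérant est rentré chez lui.",
        skip_special_tokens=True,
    )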