Fix multi-gpu inference and duplicate BOS issue for decoder-only #280

Merged · 4 commits · Jul 23, 2024

Changes from all commits
13 changes: 8 additions & 5 deletions CHANGELOG.md
@@ -4,7 +4,9 @@

## 🚀 Features

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config.
- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

- Add `rescale_attributions` to Inseq CLI commands to enable `rescale=True` during score aggregation ([#280](https://github.com/inseq-team/inseq/pull/280)).

- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)

@@ -23,7 +25,7 @@ out.save("output.json")
out.save("output_fp16.json", scores_precision="float16") # or "float8"

# Automatic conversion to float32
out_loaded = inseq.FeatureAttributionOutput.load("output_fp16.json")
out_loaded = inseq.FeatureAttributionOutput.load("output_fp16.json")
```

- A new `SliceAggregator` (`"slices"`) is added to allow slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax as `ContiguousSpanAggregator`. The `__getitem__` method of `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282)
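A minimal sketch of the slicing shortcut described above, assuming a saliency attribution on `gpt2` (both the model name and the sliced range are illustrative):

```python
import inseq

model = inseq.load_model("gpt2", "saliency")
out = model.attribute("Hello world, how are you doing today?")

# The __getitem__ shortcut applies the "slices" aggregator to a single sequence.
seq_attr = out.sequence_attributions[0]
sliced = seq_attr[2:6]  # keep only positions 2-5
sliced.show()
```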
@@ -71,18 +73,19 @@ out_female = attrib_model.attribute(
## 🔧 Fixes and Refactoring

- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal positions in the tensor were set to nan if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)).

- Fix the pad token in cases where it is not specified by default in the loaded model (e.g. for Qwen models) ([#269](https://github.com/inseq-team/inseq/pull/269)).

- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)).

- The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282)
- Fix multi-device support and duplicate BOS for chat template models ([#280](https://github.com/inseq-team/inseq/pull/280)).

- The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282)
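A hedged sketch of the duplicate-BOS scenario behind the chat-template fix above: the template already prepends the BOS token, and passing `skip_special_tokens=True` (the flag threaded through `attribute` in this PR) strips it from the formatted prompt so the tokenizer does not add a second one. The checkpoint name is illustrative and assumes access to the gated Gemma weights:

```python
import inseq

model = inseq.load_model("google/gemma-2b-it", "saliency")

# The chat template already inserts <bos> at the start of the formatted prompt.
prompt = model.tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is the capital of France?"}],
    tokenize=False,
    add_generation_prompt=True,
)

# Stripping the pre-inserted special tokens avoids a duplicated BOS in the attributed sequence.
out = model.attribute(prompt, skip_special_tokens=True)
out.show()
```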

## 📝 Documentation and Tutorials

*No changes*

## 💥 Breaking Changes

*No changes*
*No changes*
7 changes: 5 additions & 2 deletions README.md
@@ -280,7 +280,7 @@ Our vision for Inseq is to create a centralized, comprehensive and robust set of

## Citing Inseq

If you use Inseq in your research, we suggest including a mention of the specific release (e.g. v0.4.0) and kindly ask you to cite our reference paper as:
If you use Inseq in your research, we suggest including a mention of the specific release (e.g. v0.6.0) and kindly ask you to cite our reference paper as:

```bibtex
@inproceedings{sarti-etal-2023-inseq,
@@ -308,7 +308,7 @@ If you use Inseq in your research we suggest to include a mention to the specifi
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: May 2024. Please open a pull request to add your publication to the list.
> Last update: June 2024. Please open a pull request to add your publication to the list.

<details>
<summary><b>2023</b></summary>
@@ -331,6 +331,9 @@ Inseq has been used in various research projects. A list of known publications t
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://hal.science/hal-04581586">Exploring NMT Explainability for Translators Using NMT Visualising Tools</a> (Gonzalez-Saez et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2405.14899">DETAIL: Task DEmonsTration Attribution for Interpretable In-context Learning</a> (Zhou et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.06399">Should We Fine-Tune or RAG? Evaluating Different Techniques to Adapt LLMs for Dialogue</a> (Alghisi et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2406.13663">Model Internals-based Answer Attribution for Trustworthy Retrieval-Augmented Generation</a> (Qi, Sarti et al., 2024)</li>
</ol>

</details>
3 changes: 3 additions & 0 deletions inseq/commands/attribute/attribute.py
@@ -11,12 +11,14 @@ def aggregate_attribution_scores(
selectors: Optional[list[int]] = None,
aggregators: Optional[list[str]] = None,
normalize_attributions: bool = False,
rescale_attributions: bool = False,
) -> FeatureAttributionOutput:
if selectors is not None and aggregators is not None:
for select_idx, aggregator_fn in zip(selectors, aggregators):
out = out.aggregate(
aggregator=aggregator_fn,
normalize=normalize_attributions,
rescale=rescale_attributions,
select_idx=select_idx,
do_post_aggregation_checks=False,
)
@@ -79,6 +81,7 @@ def attribute(input_texts, generated_texts, args: AttributeExtendedArgs):
selectors=args.attribution_selectors,
aggregators=args.attribution_aggregators,
normalize_attributions=args.normalize_attributions,
rescale_attributions=args.rescale_attributions,
)
print(f"Saving {'aggregated ' if args.aggregate_output else ''}attributions to {args.save_path}")
out.save(args.save_path, overwrite=True)
8 changes: 8 additions & 0 deletions inseq/commands/attribute/attribute_args.py
@@ -61,6 +61,14 @@ class AttributeBaseArgs:
"for each context are normalized to sum up to 1, providing a relative notion of input salience."
),
)
rescale_attributions: bool = cli_arg(
default=False,
help=(
"Whether to rescale the attribution scores for each context. If ``True``, the attribution scores "
"for each context are rescaled to sum up to the number of tokens in the input, providing an absolute"
" notion of input salience."
),
)
model_kwargs: dict = cli_arg(
default_factory=dict,
help="Additional keyword arguments passed to the model constructor in JSON format.",
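A minimal sketch contrasting the two aggregation modes these flags map to, assuming the default aggregator forwards the `normalize`/`rescale` keyword arguments in the same way as the `aggregate_attribution_scores` helper above (model and prompt are placeholders):

```python
import inseq

model = inseq.load_model("gpt2", "saliency")
out = model.attribute("The capital of France is")

# normalize_attributions: scores for each attributed token sum to 1 (relative salience).
normalized = out.aggregate(normalize=True)

# rescale_attributions: scores sum to the number of input tokens instead (absolute salience).
rescaled = out.aggregate(rescale=True)
```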
1 change: 1 addition & 0 deletions inseq/commands/attribute_context/attribute_context.py
@@ -211,6 +211,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
selectors=args.attribution_selectors,
aggregators=args.attribution_aggregators,
normalize_attributions=args.normalize_attributions,
rescale_attributions=args.rescale_attributions,
)[0]
if args.show_intermediate_outputs:
cci_attrib_out.show(do_aggregation=False)
4 changes: 3 additions & 1 deletion inseq/models/attribution_model.py
@@ -411,7 +411,9 @@ def attribute(
"Step scores are not supported for final step methods since they do not iterate over the full"
" sequence. Please remove the step scores and compute them separatly passing method='dummy'."
)
input_texts, generated_texts = format_input_texts(input_texts, generated_texts)
input_texts, generated_texts = format_input_texts(
input_texts, generated_texts, skip_special_tokens, self.special_tokens
)
has_generated_texts = generated_texts is not None
if not self.is_encoder_decoder:
for i in range(len(input_texts)):
8 changes: 8 additions & 0 deletions inseq/models/huggingface_model.py
@@ -456,6 +456,12 @@ def get_attentions_dict(
) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.encoder_attentions is None or output.decoder_attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
if output.encoder_attentions is not None:
output.encoder_attentions = tuple(att.to("cpu") for att in output.encoder_attentions)
if output.decoder_attentions is not None:
output.decoder_attentions = tuple(att.to("cpu") for att in output.decoder_attentions)
if output.cross_attentions is not None:
output.cross_attentions = tuple(att.to("cpu") for att in output.cross_attentions)
return {
"encoder_self_attentions": torch.stack(output.encoder_attentions, dim=1),
"decoder_self_attentions": torch.stack(output.decoder_attentions, dim=1),
@@ -506,6 +512,8 @@ def configure_embeddings_scale(self):
def get_attentions_dict(output: CausalLMOutput) -> dict[str, MultiLayerMultiUnitScoreTensor]:
if output.attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
else:
output.attentions = tuple(att.to("cpu") for att in output.attentions)
return {
"decoder_self_attentions": torch.stack(output.attentions, dim=1),
}
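A minimal sketch of the multi-GPU failure mode the `.to("cpu")` calls above guard against; it assumes at least two visible CUDA devices:

```python
import torch

# With device_map="auto", attentions returned by different layers can live on different GPUs.
a = torch.randn(1, 8, 5, 5, device="cuda:0")
b = torch.randn(1, 8, 5, 5, device="cuda:1")

try:
    torch.stack((a, b), dim=1)  # fails: tensors must be on the same device
except RuntimeError as err:
    print(err)

# Moving every layer's attentions to CPU first, as the patched get_attentions_dict does,
# makes the subsequent torch.stack call safe regardless of how the model is sharded.
stacked = torch.stack(tuple(t.to("cpu") for t in (a, b)), dim=1)
```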
3 changes: 3 additions & 0 deletions inseq/models/model_config.yaml
@@ -20,6 +20,9 @@ FalconForCausalLM:
GemmaForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
Gemma2ForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
GPTBigCodeForCausalLM:
self_attention_module: "attn"
value_vector: "value"
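With Gemma 2 registered in the config above, attention-based methods can locate its value vectors. A quick sketch, assuming access to the gated `google/gemma-2-9b` checkpoint (any Gemma 2 causal LM would do):

```python
import inseq

# value_zeroing reads the self_attention_module / value_vector names registered in model_config.yaml.
model = inseq.load_model("google/gemma-2-9b", "value_zeroing")
out = model.attribute("The capital of France is")
out.show()
```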
7 changes: 7 additions & 0 deletions inseq/utils/misc.py
@@ -202,6 +202,8 @@ def isnotebook():
def format_input_texts(
texts: TextInput,
ref_texts: Optional[TextInput] = None,
skip_special_tokens: bool = False,
special_tokens: list[str] = [],
) -> tuple[list[str], list[str]]:
texts = [texts] if isinstance(texts, str) else texts
reference_texts = [ref_texts] if isinstance(ref_texts, str) else ref_texts
@@ -211,6 +213,11 @@
len(texts), len(reference_texts)
)
)
if skip_special_tokens:
for special_token in special_tokens:
texts = [text.replace(special_token, "") for text in texts]
if reference_texts is not None:
reference_texts = [text.replace(special_token, "") for text in reference_texts]
return texts, reference_texts
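A small illustration of the new stripping behaviour, with placeholder special tokens; in practice the list comes from the loaded model's `special_tokens`:

```python
from inseq.utils.misc import format_input_texts

texts, refs = format_input_texts(
    "<bos>Hello, how are you?",
    ref_texts="<bos>Hello, how are you? I am fine.<eos>",
    skip_special_tokens=True,
    special_tokens=["<bos>", "<eos>"],
)
print(texts)  # ['Hello, how are you?']
print(refs)   # ['Hello, how are you? I am fine.']
```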

