diff --git a/inseq/__init__.py b/inseq/__init__.py index 697f70d..cc20a1c 100644 --- a/inseq/__init__.py +++ b/inseq/__init__.py @@ -8,6 +8,7 @@ merge_attributions, show_attributions, show_granular_attributions, + show_token_attributions, ) from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config from .utils.id_utils import explain diff --git a/inseq/data/__init__.py b/inseq/data/__init__.py index bd5322b..99e20f5 100644 --- a/inseq/data/__init__.py +++ b/inseq/data/__init__.py @@ -31,7 +31,7 @@ EncoderDecoderBatch, slice_batch_from_position, ) -from .viz import show_attributions, show_granular_attributions +from .viz import show_attributions, show_granular_attributions, show_token_attributions __all__ = [ "Aggregator", @@ -59,6 +59,7 @@ "TextInput", "show_attributions", "show_granular_attributions", + "show_token_attributions", "list_aggregation_functions", "MultiDimensionalFeatureAttributionStepOutput", "get_batch_from_inputs", diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index d7532ec..6bbaca1 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -38,7 +38,7 @@ from .aggregator import AggregableMixin, Aggregator, AggregatorPipeline from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch from .data_utils import TensorWrapper -from .viz import get_saliency_heatmap_treescope +from .viz import get_saliency_heatmap_treescope, get_tokens_heatmap_treescope if TYPE_CHECKING: from ..models import AttributionModel @@ -235,12 +235,23 @@ def granular_attribution_visualizer( else: return treescope.IPythonVisualization( treescope.figures.inline( - adapter.get_array_summary(value, fast=False), - treescope.render_array( - value, - axis_labels={0: f"Generated Tokens: {value.shape[0]}"}, - axis_item_labels={0: column_labels}, + adapter.get_array_summary(value, fast=False) + "\n\n", + treescope.figures.figure_from_treescope_rendering_part( + 
treescope.rendering_parts.indented_children( + [ + get_tokens_heatmap_treescope( + tokens=column_labels, + scores=value.numpy(), + max_val=value.max().item(), + ) + ] + ) ), + # treescope.render_array( + # value, + # axis_labels={0: f"Generated Tokens: {value.shape[0]}"}, + # axis_item_labels={0: column_labels}, + # ), ), replace=True, ) @@ -431,6 +442,7 @@ def show( slice_dims: dict[int | str, tuple[int, int]] | None = None, display: bool = True, return_html: bool | None = False, + return_figure: bool = False, aggregator: AggregatorPipeline | type[Aggregator] = None, do_aggregation: bool = True, **kwargs, @@ -460,6 +472,8 @@ def show( for later use. return_html (:obj:`bool`, *optional*, defaults to False): Whether to return the HTML code of the visualization. + return_figure (:obj:`bool`, *optional*, defaults to False): + For granular visualization, whether to return the Treescope figure object for further manipulation. aggregator (:obj:`AggregatorPipeline`, *optional*, defaults to None): Aggregates attributions before visualizing them. If not specified, the default aggregator for the class is used. 
def show_tokens(
    self,
    min_val: int | None = None,
    max_val: int | None = None,
    display: bool = True,
    return_html: bool | None = False,
    return_figure: bool = False,
    replace_char: dict[str, str] | None = None,
    wrap_after: int | str | list[str] | tuple[str] | None = None,
    step_score_highlight: str | None = None,
    aggregator: AggregatorPipeline | type[Aggregator] = None,
    do_aggregation: bool = True,
    **kwargs,
) -> str | None:
    """Visualize this sequence attribution as token-level heatmaps.

    The attribution is optionally aggregated first, then handed to
    ``inseq.show_token_attributions`` together with the display options.

    Args:
        min_val: Lower attribution score threshold for the color map.
        max_val: Upper attribution score threshold for the color map.
        display: Whether to show the visualization.
        return_html: Whether to return the visualization as an HTML string.
        return_figure: Whether to return the figure object instead of HTML.
        replace_char: Mapping of substrings to their replacements, forwarded as-is.
        wrap_after: Line-wrapping rule, forwarded as-is.
        step_score_highlight: Name of the step score used to highlight generated tokens.
        aggregator: Aggregator or pipeline applied when ``do_aggregation`` is True.
        do_aggregation: If False, the attribution is visualized without aggregating.
        **kwargs: Extra keyword arguments passed to ``self.aggregate``.

    Returns:
        Whatever ``show_token_attributions`` returns for the given options.
    """
    # Imported lazily, as in the other show_* methods of this class.
    from inseq import show_token_attributions

    if do_aggregation:
        to_visualize = self.aggregate(aggregator, **kwargs)
    else:
        to_visualize = self
    viz_options = {
        "min_val": min_val,
        "max_val": max_val,
        "display": display,
        "return_html": return_html,
        "return_figure": return_figure,
        "replace_char": replace_char,
        "wrap_after": wrap_after,
        "step_score_highlight": step_score_highlight,
    }
    return show_token_attributions(attributions=to_visualize, **viz_options)
def show_tokens(
    self,
    min_val: int | None = None,
    max_val: int | None = None,
    display: bool = True,
    return_html: bool = False,
    return_figure: bool = False,
    replace_char: dict[str, str] | None = None,
    wrap_after: int | str | list[str] | tuple[str] | None = None,
    step_score_highlight: str | None = None,
    aggregator: AggregatorPipeline | type[Aggregator] = None,
    do_aggregation: bool = True,
    **kwargs,
) -> str | list | None:
    """Visualize token-level attributions for every sequence attribution.

    Calls ``show_tokens`` on each element of ``self.sequence_attributions``
    with the same options and collects the per-sequence results.

    Args:
        min_val: Lower attribution score threshold for the color map.
        max_val: Upper attribution score threshold for the color map.
        display: If True, display each visualization.
        return_html: If True, return the concatenated HTML of all sequences.
        return_figure: If True, return one figure object per sequence.
        replace_char: Mapping of substrings to their replacements, forwarded as-is.
        wrap_after: Line-wrapping rule, forwarded as-is.
        step_score_highlight: Name of the step score used to highlight generated tokens.
        aggregator: Aggregator or pipeline forwarded to every sequence attribution.
        do_aggregation: Whether the sequence attributions should be aggregated first.
        **kwargs: Extra keyword arguments forwarded to every ``show_tokens`` call.

    Returns:
        str: Concatenated HTML if ``return_html=True``.
        list: Per-sequence outputs if ``return_figure=True``.
        None: Otherwise.
    """
    outputs = [
        attr.show_tokens(
            min_val=min_val,
            max_val=max_val,
            display=display,
            return_html=return_html,
            return_figure=return_figure,
            replace_char=replace_char,
            wrap_after=wrap_after,
            step_score_highlight=step_score_highlight,
            aggregator=aggregator,
            do_aggregation=do_aggregation,
            **kwargs,
        )
        for attr in self.sequence_attributions
    ]
    if return_html:
        # Single join instead of repeated string concatenation.
        return "".join(outputs)
    if return_figure:
        return outputs
    return None
def show_token_attributions(
    attributions: "FeatureAttributionSequenceOutput",
    min_val: int | None = None,
    max_val: int | None = None,
    display: bool = True,
    return_html: bool | None = False,
    return_figure: bool = False,
    replace_char: dict[str, str] | None = None,
    wrap_after: int | str | list[str] | tuple[str] | None = None,
    step_score_highlight: str | None = None,
):
    """Visualizes token-level attributions in HTML format.

    Every generated token is rendered as a foldable node; expanding it shows the source and/or target tokens
    colored by their attribution score for that generation step.

    Args:
        attributions (:class:`~inseq.data.attribution.FeatureAttributionSequenceOutput`):
            Sequence attributions to be visualized. A single object or a list of objects.
        min_val (:obj:`Optional[int]`, *optional*, defaults to None):
            Lower attribution score threshold for color map.
        max_val (`Optional[int]`, *optional*, defaults to None):
            Upper attribution score threshold for color map. If None, the largest ``maximum`` among the provided
            attributions is used.
        display (`bool`, *optional*, defaults to True):
            Whether to show the output of the visualization function.
        return_html (`Optional[bool]`, *optional*, defaults to False):
            If true, returns the HTML corresponding to the notebook visualization of the attributions in string format,
            for saving purposes.
        return_figure (`Optional[bool]`, *optional*, defaults to False):
            If true, returns the Treescope figure object for further manipulation.
        replace_char (`Optional[dict[str, str]]`, *optional*, defaults to None):
            Dictionary mapping strings to be replaced to replacement options, used for cleaning special characters.
            Default: {}.
        wrap_after (`Optional[int | str | list[str] | tuple[str]]`, *optional*, defaults to None):
            Rule for wrapping lines. An integer wraps periodically every that many tokens, a string wraps after
            every token equal to it (e.g. "."), and a list or tuple wraps after every token it contains
            (e.g. [".", "!", "?"] for sentence-ending punctuation).
        step_score_highlight (`Optional[str]`, *optional*, defaults to None):
            Name of the step score to use to highlight generated tokens in the visualization. If None, no highlights are
            shown. Default: None.

    Raises:
        ValueError: If ``display=True`` outside of a notebook, if both ``return_html`` and ``return_figure`` are
            set, or if ``step_score_highlight`` is missing from any of the provided attributions.

    Returns:
        The Treescope figure if ``return_figure=True``, the HTML string if ``return_html=True``, None otherwise.
    """
    from inseq.data.attribution import FeatureAttributionSequenceOutput

    if isinstance(attributions, FeatureAttributionSequenceOutput):
        attributions = [attributions]
    if not isnotebook() and display:
        raise ValueError(
            "Token attribution visualization is only supported in Jupyter notebooks. "
            "Please set `display=False` and `return_html=True` to avoid this error."
        )
    if return_html and return_figure:
        raise ValueError("Only one of `return_html` and `return_figure` can be set to True.")
    if replace_char is None:
        replace_char = {}
    if max_val is None:
        # `default=None` keeps an empty list of attributions from raising.
        max_val = max((attribution.maximum for attribution in attributions), default=None)
    # Validate every attribution, not only the first one: each of them is indexed by the step score below.
    if step_score_highlight is not None and any(
        attr.step_scores is None or step_score_highlight not in attr.step_scores for attr in attributions
    ):
        raise ValueError(
            f'The requested step score "{step_score_highlight}" is not available for highlights in the provided '
            "attribution object. Please set `step_score_highlight=None` or recompute `model.attribute` by passing "
            f'`step_scores=["{step_score_highlight}"]`.'
        )
    heatmap_title_style = "font-style: italic; color: #888888;"
    generated_token_parts = []
    for attr in attributions:
        # `replace_char` is forwarded so that the requested special-character cleanup is actually applied.
        cleaned_generated_tokens = clean_tokens(
            (t.token for t in attr.target[attr.attr_pos_start : attr.attr_pos_end]), replace_chars=replace_char
        )
        cleaned_input_tokens = clean_tokens((t.token for t in attr.source), replace_chars=replace_char)
        cleaned_target_tokens = clean_tokens((t.token for t in attr.target), replace_chars=replace_char)
        step_scores = None
        title = "Generated text:\n\n"
        if step_score_highlight is not None:
            step_scores = attr.step_scores[step_score_highlight]
            scores_vmin = step_scores.min().item()
            scores_vmax = step_scores.max().item()
            title = f"Generated text with {step_score_highlight} highlights:\n\n"
        generated_token_parts.append(rp.custom_style(rp.text(title), css_style="font-weight: bold;"))
        for gen_idx, curr_gen_tok in enumerate(cleaned_generated_tokens):
            # Content shown when the node of the current generated token is expanded.
            attributed_token_parts = [rp.text("\n")]
            if attr.source_attributions is not None:
                attributed_token_parts.append(
                    get_tokens_heatmap_treescope(
                        tokens=cleaned_input_tokens,
                        scores=attr.source_attributions[:, gen_idx].numpy(),
                        title=f'Source attributions for "{curr_gen_tok}"',
                        title_style=heatmap_title_style,
                        min_val=min_val,
                        max_val=max_val,
                        wrap_after=wrap_after,
                    )
                )
                attributed_token_parts.append(rp.text("\n\n"))
            if attr.target_attributions is not None:
                attributed_token_parts.append(
                    get_tokens_heatmap_treescope(
                        # Only the target prefix preceding the current generated token is shown.
                        tokens=cleaned_target_tokens[: attr.attr_pos_start + gen_idx],
                        scores=attr.target_attributions[:, gen_idx].numpy(),
                        title=f'Target attributions for "{curr_gen_tok}"',
                        title_style=heatmap_title_style,
                        min_val=min_val,
                        max_val=max_val,
                        wrap_after=wrap_after,
                    )
                )
                attributed_token_parts.append(rp.text("\n\n"))
            if step_scores is not None:
                gen_tok_label = fg.treescope_part_from_display_object(
                    fg.text_on_color(
                        curr_gen_tok, value=round(step_scores[gen_idx].item(), 4), vmin=scores_vmin, vmax=scores_vmax
                    )
                )
            else:
                gen_tok_label = rp.text(curr_gen_tok)
            generated_token_parts.append(
                rp.build_full_line_with_annotations(
                    rp.build_custom_foldable_tree_node(
                        label=gen_tok_label,
                        contents=rp.fold_condition(
                            collapsed=rp.text(" "),
                            expanded=rp.indented_children([rp.siblings(*attributed_token_parts)]),
                        ),
                    )
                )
            )
    fig = fg.figure_from_treescope_rendering_part(
        rp.custom_style(rp.siblings(*generated_token_parts), css_style="white-space: pre-wrap")
    )
    if return_figure:
        return fig
    if display:
        ts.show(fig)
    if return_html:
        return ts.render_to_html(fig)


def get_tokens_heatmap_treescope(
    tokens: list[str],
    scores: np.ndarray,
    title: str | None = None,
    title_style: str | None = None,
    min_val: float | None = None,
    max_val: float | None = None,
    wrap_after: int | str | list[str] | tuple[str] | None = None,
):
    """Builds a Treescope rendering part showing tokens colored by their scores.

    Args:
        tokens: Tokens to render, in order.
        scores: Scores indexed by token position; ``scores[idx]`` is read for every token.
        title: Optional title rendered before the tokens, followed by a colon and a newline.
        title_style: CSS style applied to the title.
        min_val: Lower bound of the color scale, passed as ``vmin``.
        max_val: Upper bound of the color scale, passed as ``vmax``.
        wrap_after: Line-wrapping rule forwarded to ``maybe_add_linebreak``.

    Returns:
        A ``siblings`` rendering part with the title (if any) and the colored tokens.
    """
    parts = []
    if title is not None:
        parts.append(
            rp.custom_style(
                rp.text(title + ":\n"),
                css_style=title_style,
            )
        )
    for idx, tok in enumerate(tokens):
        # Tokens with a NaN score are not rendered at all.
        if not np.isnan(scores[idx]):
            parts.append(
                fg.treescope_part_from_display_object(
                    fg.text_on_color(tok, value=round(scores[idx], 4), vmin=min_val, vmax=max_val)
                )
            )
            # A linebreak is only emitted after a token that was actually rendered.
            parts += maybe_add_linebreak(tok, idx, wrap_after)
    return rp.siblings(*parts)
Align word-level alignments to token-level alignments from the generative model tokenizer. # Requires cleaning up the model tokens from special tokens (special characters already removed) - clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens) - clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens) + clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens, return_removed_idxs=True) + clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens, return_removed_idxs=True) if len(removed_a_token_idxs) != len(removed_b_token_idxs): logger.debug( "The number of special tokens in the target and contrast sequences do not match. " diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py index 1f49d14..5843f07 100644 --- a/inseq/utils/misc.py +++ b/inseq/utils/misc.py @@ -431,16 +431,27 @@ def get_cls_from_instance_type(mod, name, cls_lookup_map): return curr_class -def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str], list[int]]: +def clean_tokens( + tokens: list[str], + remove_tokens: list[str] = [], + return_removed_idxs: bool = False, + replace_chars: dict[str, str] | None = None, +) -> list[str] | tuple[list[str], list[int]]: """Removes tokens from a list of tokens and returns the cleaned list and the removed token indexes.""" clean_tokens = [] removed_token_idxs = [] for idx, tok in enumerate(tokens): - if tok not in remove_tokens: - clean_tokens += [tok.strip()] - else: + new_tok = tok + if new_tok in remove_tokens: removed_token_idxs += [idx] - return clean_tokens, removed_token_idxs + else: + if replace_chars is not None: + for k, v in replace_chars.items(): + new_tok = new_tok.replace(k, v) + clean_tokens += [new_tok.strip()] + if return_removed_idxs: + return clean_tokens, removed_token_idxs + return clean_tokens def get_left_padding(text: str): diff --git a/inseq/utils/viz_utils.py b/inseq/utils/viz_utils.py index 
ed7f7e3..b0672ee 100644 --- a/inseq/utils/viz_utils.py +++ b/inseq/utils/viz_utils.py @@ -19,6 +19,7 @@ import matplotlib.pyplot as plt import numpy as np +import treescope as ts from matplotlib.colors import Colormap, LinearSegmentedColormap from numpy.typing import NDArray @@ -127,6 +128,17 @@ def test_dim(dim: int | str, dim_names: dict[int, str], rev_dim_names: dict[str, return dim_idx +def maybe_add_linebreak(tok: str, i: int, wrap_after: int | str | list[str] | tuple[str]) -> list[str]: + if isinstance(wrap_after, str) and tok == wrap_after: + return [ts.rendering_parts.text("\n")] + elif isinstance(wrap_after, list | tuple) and tok in wrap_after: + return [ts.rendering_parts.text("\n")] + elif isinstance(wrap_after, int) and i % wrap_after == 0: + return [ts.rendering_parts.text("\n")] + else: + return [] + + # Full plot final_plot_html = """