Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/reagent branch for ReAGent #250

Merged
merged 17 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions inseq/attr/feat/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .discretized_integrated_gradients import DiscretetizedIntegratedGradients
from .lime import Lime
from .monotonic_path_builder import MonotonicPathBuilder
from .reagent import Reagent
from .sequential_integrated_gradients import SequentialIntegratedGradients
from .value_zeroing import ValueZeroing

Expand All @@ -9,5 +10,6 @@
"MonotonicPathBuilder",
"ValueZeroing",
"Lime",
"Reagent",
"SequentialIntegratedGradients",
]
130 changes: 130 additions & 0 deletions inseq/attr/feat/ops/reagent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import TYPE_CHECKING, Any, Union

import torch
from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
from torch import Tensor
from typing_extensions import override

from ....utils.typing import InseqAttribution
from .reagent_core import (
AggregateRationalizer,
DeltaProbImportanceScoreEvaluator,
POSTagTokenSampler,
TopKStoppingConditionEvaluator,
UniformTokenReplacer,
)

if TYPE_CHECKING:
from ....models import HuggingfaceModel


class Reagent(InseqAttribution):
    r"""Recursive attribution generator (ReAGent) method.

    Measures importance as the drop in prediction probability produced by replacing a token with a plausible
    alternative predicted by a LM.

    Reference implementation:
    `ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models
    <https://arxiv.org/abs/2402.00794>`__

    Args:
        forward_func (callable): The forward function of the model or any modification of it
        keep_top_n (int): If set to a value greater than 0, the top n tokens based on their importance score will be
            kept during the prediction inference. If set to 0, the top n will be determined by ``keep_ratio``.
        keep_ratio (float, optional): If ``keep_top_n`` is set to 0, this specifies the proportion of tokens to keep.
        invert_keep (bool): If specified, the top tokens selected either via ``keep_top_n`` or ``keep_ratio`` will be
            replaced instead of being kept.
        stopping_condition_top_k (int): Threshold indicating that the stop condition is achieved when the predicted
            target exists in the top k predictions
        replacing_ratio (float): replacing ratio of tokens for probing
        max_probe_steps (int): max number of probe steps per attribution
        num_probes (int): number of probes run in parallel

    Example:
        ```
        import inseq

        model = inseq.load_model("gpt2-medium", "reagent",
            keep_top_n=5,
            stopping_condition_top_k=3,
            replacing_ratio=0.3,
            max_probe_steps=3000,
            num_probes=8
        )
        out = model.attribute("Super Mario Land is a game that developed by")
        out.show()
        ```
    """

    def __init__(
        self,
        attribution_model: "HuggingfaceModel",
        keep_top_n: int = 5,
        keep_ratio: Union[float, None] = None,
        invert_keep: bool = False,
        stopping_condition_top_k: int = 3,
        replacing_ratio: float = 0.3,
        max_probe_steps: int = 3000,
        num_probes: int = 16,
    ) -> None:
        super().__init__(attribution_model)

        model = attribution_model.model
        tokenizer = attribution_model.tokenizer
        model_name = attribution_model.model_name

        # Sampler proposing plausible replacement tokens based on POS tags
        sampler = POSTagTokenSampler(tokenizer=tokenizer, identifier=model_name, device=attribution_model.device)
        # Decides per batch element when importance scores have converged
        stopping_condition_evaluator = TopKStoppingConditionEvaluator(
            model=model,
            sampler=sampler,
            top_k=stopping_condition_top_k,
            keep_top_n=keep_top_n,
            keep_ratio=keep_ratio,
            invert_keep=invert_keep,
        )
        # Scores tokens by the probability drop caused by replacing them
        importance_score_evaluator = DeltaProbImportanceScoreEvaluator(
            model=model,
            tokenizer=tokenizer,
            token_replacer=UniformTokenReplacer(sampler=sampler, ratio=replacing_ratio),
            stopping_condition_evaluator=stopping_condition_evaluator,
            max_steps=max_probe_steps,
        )

        self.rationalizer = AggregateRationalizer(
            importance_score_evaluator=importance_score_evaluator,
            batch_size=num_probes,
            overlap_threshold=0,
            overlap_strict_pos=True,
            keep_top_n=keep_top_n,
            keep_ratio=keep_ratio,
        )

    @override
    def attribute(  # type: ignore
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        _target: TargetType = None,
        additional_forward_args: Any = None,
    ) -> Union[
        TensorOrTupleOfTensorsGeneric,
        tuple[TensorOrTupleOfTensorsGeneric, Tensor],
    ]:
        """Compute ReAGent attributions for the given inputs.

        The arity of ``additional_forward_args`` encodes the attribution setting:
        8 args is encoder-decoder with target-side attribution, 9 is encoder-decoder
        source-side only, 6 is decoder-only.

        NOTE(review): assumes ``additional_forward_args`` is not ``None`` and matches
        one of the known arities — confirm against the inseq feature-attribution callers.

        Returns:
            A tuple of attribution tensors broadcast to the embedding shape of ``inputs[0]``;
            two tensors (source, target) in the encoder-decoder-with-target case, one otherwise.
        """
        if len(additional_forward_args) == 8:
            # encoder-decoder with target
            self.rationalizer(additional_forward_args[0], additional_forward_args[2], additional_forward_args[1], True)

            # Broadcast per-token scores over the embedding dimension to match inputs' shape
            mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0)
            res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2])
            # Split concatenated scores back into (source, target) segments
            return (res[:, : additional_forward_args[0].shape[1], :], res[:, additional_forward_args[0].shape[1] :, :])
        elif len(additional_forward_args) == 9:
            # encoder-decoder
            self.rationalizer(additional_forward_args[1], additional_forward_args[3], additional_forward_args[2])
        elif len(additional_forward_args) == 6:
            # decoder only
            self.rationalizer(additional_forward_args[0], additional_forward_args[1])

        mean_important_score = torch.unsqueeze(self.rationalizer.mean_important_score, 0)
        res = torch.unsqueeze(mean_important_score, 2).repeat(1, 1, inputs[0].shape[2])
        return (res,)
13 changes: 13 additions & 0 deletions inseq/attr/feat/ops/reagent_core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .importance_score_evaluator import DeltaProbImportanceScoreEvaluator
from .rationalizer import AggregateRationalizer
from .stopping_condition_evaluator import TopKStoppingConditionEvaluator
from .token_replacer import UniformTokenReplacer
from .token_sampler import POSTagTokenSampler

__all__ = [
"DeltaProbImportanceScoreEvaluator",
"AggregateRationalizer",
"TopKStoppingConditionEvaluator",
"UniformTokenReplacer",
"POSTagTokenSampler",
]
220 changes: 220 additions & 0 deletions inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Optional

import torch
from jaxtyping import Float
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from typing_extensions import override

from .....utils.typing import IdsTensor, MultipleScoresPerStepTensor, TargetIdsTensor
from .stopping_condition_evaluator import StoppingConditionEvaluator
from .token_replacer import TokenReplacer


class BaseImportanceScoreEvaluator(ABC):
    """Abstract base for importance score evaluators.

    Holds the model/tokenizer pair shared by all concrete evaluators and declares
    the scoring interface that subclasses must implement.
    """

    def __init__(self, model: AutoModelForCausalLM | AutoModelForSeq2SeqLM, tokenizer: AutoTokenizer) -> None:
        """Store the model and tokenizer used for scoring.

        Args:
            model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
            tokenizer: A Huggingface AutoTokenizer

        """
        # Latest importance scores; filled in by concrete subclasses
        self.important_score = None
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def __call__(
        self,
        input_ids: IdsTensor,
        target_id: TargetIdsTensor,
        decoder_input_ids: Optional[IdsTensor] = None,
        attribute_target: bool = False,
    ) -> MultipleScoresPerStepTensor:
        """Compute an importance score for every token of the input sequence.

        Args:
            input_ids: input sequence [batch, sequence]
            target_id: target token [batch]
            decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
            attribute_target: whether attribute target for encoder-decoder models

        Return:
            importance_score: evaluated importance score for each token in the input [batch, sequence]

        """
        raise NotImplementedError()


class DeltaProbImportanceScoreEvaluator(BaseImportanceScoreEvaluator):
    """Importance Score Evaluator

    Estimates token importance as the change (delta) in the target token's predicted
    probability when random subsets of tokens are replaced by sampled alternatives,
    accumulating evidence over repeated probe steps until the stopping condition holds
    for every batch element (or ``max_steps`` is reached).
    """

    @override
    def __init__(
        self,
        model: AutoModelForCausalLM | AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        token_replacer: TokenReplacer,
        stopping_condition_evaluator: StoppingConditionEvaluator,
        max_steps: int,
    ) -> None:
        """Constructor

        Args:
            model: A Huggingface AutoModelForCausalLM or AutoModelForSeq2SeqLM model
            tokenizer: A Huggingface AutoTokenizer
            token_replacer: A TokenReplacer that swaps a random subset of tokens at each probe step
            stopping_condition_evaluator: A StoppingConditionEvaluator deciding per batch element
                when the importance score has stabilized
            max_steps: upper bound on the number of probe steps performed per call

        """

        super().__init__(model, tokenizer)

        self.token_replacer = token_replacer
        self.stopping_condition_evaluator = stopping_condition_evaluator
        self.max_steps = max_steps

        # Softmax-normalized importance scores of the most recent step ([batch, sequence])
        self.important_score = None
        # Number of probe steps executed by the last __call__
        self.num_steps = 0

    def update_importance_score(
        self,
        logit_importance_score: MultipleScoresPerStepTensor,
        input_ids: IdsTensor,
        target_id: TargetIdsTensor,
        prob_original_target: Float[torch.Tensor, "batch_size 1"],
        decoder_input_ids: Optional[IdsTensor] = None,
        attribute_target: bool = False,
    ) -> MultipleScoresPerStepTensor:
        """Update importance score by one step

        Args:
            logit_importance_score: Current importance score in logistic scale [batch, sequence]
            input_ids: input tensor [batch, sequence]
            target_id: target tensor [batch]
            prob_original_target: predictive probability of the target on the original sequence [batch, 1]
            decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
            attribute_target: whether attribute target for encoder-decoder models

        Return:
            logit_importance_score: updated importance score in logistic scale [batch, sequence]

        """
        # Randomly replace a set of tokens R to form a new sequence \hat{y_{1...t}}

        if not attribute_target:
            input_ids_replaced, mask_replacing = self.token_replacer(input_ids)
        else:
            # When attributing the target too, replace over the concatenated source+target
            # sequence, then split the result back into its two halves.
            ids_replaced, mask_replacing = self.token_replacer(torch.cat((input_ids, decoder_input_ids), 1))
            input_ids_replaced = ids_replaced[:, : input_ids.shape[1]]
            decoder_input_ids_replaced = ids_replaced[:, input_ids.shape[1] :]

        logging.debug(f"Replacing mask: { mask_replacing }")
        logging.debug(
            f"Replaced sequence: { [[ self.tokenizer.decode(seq[i]) for i in range(input_ids_replaced.shape[1]) ] for seq in input_ids_replaced ] }"
        )

        # Inference \hat{p^{(y)}} = p(y_{t+1}|\hat{y_{1...t}})

        if decoder_input_ids is None:
            # decoder-only model
            logits_replaced = self.model(input_ids_replaced)["logits"]
        elif not attribute_target:
            # encoder-decoder, source-side replacement only
            logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids)["logits"]
        else:
            # encoder-decoder with replacements on both source and target
            logits_replaced = self.model(input_ids=input_ids_replaced, decoder_input_ids=decoder_input_ids_replaced)[
                "logits"
            ]

        # NOTE(review): advanced indexing with a batched target_id yields a [batch, batch]
        # tensor; presumably target_id is a scalar or single-element tensor here — verify.
        prob_replaced_target = torch.softmax(logits_replaced[:, -1, :], -1)[:, target_id]

        # Compute changes delta = p^{(y)} - \hat{p^{(y)}}

        delta_prob_target = prob_original_target - prob_replaced_target
        logging.debug(f"likelihood delta: { delta_prob_target }")

        # Update importance scores based on delta (magnitude) and replacement (direction):
        # replaced tokens gain +delta, untouched tokens gain -delta.

        delta_score = mask_replacing * delta_prob_target + ~mask_replacing * -delta_prob_target
        # TODO: better solution?
        # Rescaling from [-1, 1] to [0, 1] before logit function
        logit_delta_score = torch.logit(delta_score * 0.5 + 0.5)
        logit_importance_score = logit_importance_score + logit_delta_score
        logging.debug(f"Updated importance score: { torch.softmax(logit_importance_score, -1) }")

        return logit_importance_score

    @override
    def __call__(
        self,
        input_ids: IdsTensor,
        target_id: TargetIdsTensor,
        decoder_input_ids: Optional[IdsTensor] = None,
        attribute_target: bool = False,
    ) -> MultipleScoresPerStepTensor:
        """Evaluate importance score of input sequence

        Args:
            input_ids: input sequence [batch, sequence]
            target_id: target token [batch]
            decoder_input_ids (optional): decoder input sequence for AutoModelForSeq2SeqLM [batch, sequence]
            attribute_target: whether attribute target for encoder-decoder models

        Return:
            importance_score: evaluated importance score for each token in the input [batch, sequence]

        """

        # Tracks which batch elements already met the stopping condition; their scores
        # are frozen on subsequent steps.
        self.stop_mask = torch.zeros([input_ids.shape[0]], dtype=torch.bool, device=input_ids.device)

        # Inference p^{(y)} = p(y_{t+1}|y_{1...t})
        if decoder_input_ids is None:
            logits_original = self.model(input_ids)["logits"]
        else:
            logits_original = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)["logits"]

        # NOTE(review): same batched-target_id indexing caveat as in update_importance_score.
        prob_original_target = torch.softmax(logits_original[:, -1, :], -1)[:, target_id]

        # Initialize importance score s for each token in the sequence y_{1...t}

        if not attribute_target:
            logit_importance_score = torch.rand(input_ids.shape, device=input_ids.device)
        else:
            # Scores span the concatenated source+target sequence
            logit_importance_score = torch.rand(
                (input_ids.shape[0], input_ids.shape[1] + decoder_input_ids.shape[1]), device=input_ids.device
            )
        logging.debug(f"Initialize importance score -> { torch.softmax(logit_importance_score, -1) }")

        # TODO: limit max steps
        self.num_steps = 0
        while self.num_steps < self.max_steps:
            self.num_steps += 1

            # Update importance score
            logit_importance_score_update = self.update_importance_score(
                logit_importance_score, input_ids, target_id, prob_original_target, decoder_input_ids, attribute_target
            )
            # Apply the update only to batch elements that have not stopped yet
            logit_importance_score = (
                ~torch.unsqueeze(self.stop_mask, 1) * logit_importance_score_update
                + torch.unsqueeze(self.stop_mask, 1) * logit_importance_score
            )

            self.important_score = torch.softmax(logit_importance_score, -1)

            # Evaluate stop condition
            self.stop_mask = self.stop_mask | self.stopping_condition_evaluator(
                input_ids, target_id, self.important_score, decoder_input_ids, attribute_target
            )
            # Exit early once every batch element has converged
            if torch.prod(self.stop_mask) > 0:
                break

        logging.info(f"Importance score evaluated in {self.num_steps} steps.")

        return torch.softmax(logit_importance_score, -1)
Loading
Loading