repr_tools.py

"""
Contains utilities for extracting token representations and indices
from string templates. Used in computing the left and right vectors for ROME.
"""

from copy import deepcopy
from typing import List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook


def get_reprs_at_word_tokens(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    context_templates: List[str],
    words: List[str],
    layer: int,
    module_template: str,
    subtoken: str,
    track: str = "in",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Retrieves the last token representation of `word` in `context_template`
    when `word` is substituted into `context_template`. See
    `get_words_idxs_in_templates` for more details.
    """

    idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken)
    return get_reprs_at_idxs(
        model,
        tok,
        [context_templates[i].format(words[i]) for i in range(len(words))],
        idxs,
        layer,
        module_template,
        track,
    )
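
# Example usage (a minimal sketch; "transformer.h.{}.mlp" assumes GPT-2-style
# module naming, and other architectures name their submodules differently):
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   reprs, _ = get_reprs_at_word_tokens(
#       model, tok,
#       context_templates=["{} plays basketball"],
#       words=["LeBron James"],
#       layer=5,
#       module_template="transformer.h.{}.mlp",
#       subtoken="last",
#   )  # with track="in", the second element is a zero placeholder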


def get_words_idxs_in_templates(
    tok: AutoTokenizer, context_templates: List[str], words: List[str], subtoken: str
) -> List[List[int]]:
    """
    Given a list of template strings, each with *one* format specifier
    (e.g. "{} plays basketball"), and words to be substituted into the
    template, computes the post-tokenization index of their last tokens.
    """

    assert all(
        tmp.count("{}") == 1 for tmp in context_templates
    ), "We currently do not support multiple fill-ins for context"

    # Split each template into the text before and after the fill-in
    fill_idxs = [tmp.index("{}") for tmp in context_templates]
    prefixes, suffixes = [
        tmp[: fill_idxs[i]] for i, tmp in enumerate(context_templates)
    ], [tmp[fill_idxs[i] + 2 :] for i, tmp in enumerate(context_templates)]
    words = deepcopy(words)

    # Pre-process tokens: move the space that precedes the fill-in onto the
    # word itself, so the word tokenizes exactly as it would mid-sentence
    for i, prefix in enumerate(prefixes):
        if len(prefix) > 0:
            assert prefix[-1] == " "
            prefix = prefix[:-1]

            prefixes[i] = prefix
            words[i] = f" {words[i].strip()}"

    # Tokenize prefixes, words, and suffixes in one batch to determine lengths
    assert len(prefixes) == len(words) == len(suffixes)
    n = len(prefixes)
    batch_tok = tok([*prefixes, *words, *suffixes])["input_ids"]
    prefixes_tok, words_tok, suffixes_tok = [
        batch_tok[i : i + n] for i in range(0, n * 3, n)
    ]
    prefixes_len, words_len, suffixes_len = [
        [len(el) for el in tok_list]
        for tok_list in [prefixes_tok, words_tok, suffixes_tok]
    ]

    # Compute indices of last tokens
    if subtoken == "last" or subtoken == "first_after_last":
        return [
            [
                prefixes_len[i]
                + words_len[i]
                - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0)
            ]
            # If the suffix is empty, there is no "first token after the last",
            # so just return the last token of the word.
            for i in range(n)
        ]
    elif subtoken == "first":
        return [[prefixes_len[i]] for i in range(n)]
    else:
        raise ValueError(f"Unknown subtoken type: {subtoken}")
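
# Worked example (token counts vary by tokenizer; the numbers are
# illustrative): for the template "{} plays basketball" and the word
# "LeBron James", the prefix is "" (0 tokens), so no leading space is added
# to the word. If the word encodes to 3 subword tokens, subtoken="last"
# yields [0 + 3 - 1] = [2] (the word's final subword), subtoken="first"
# yields [0], and subtoken="first_after_last" yields [3] (the first token
# of " plays basketball").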


def get_reprs_at_idxs(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    contexts: List[str],
    idxs: List[List[int]],
    layer: int,
    module_template: str,
    track: str = "in",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Runs input through model and returns averaged representations of the tokens
    at each index in `idxs`.
    """

    def _batch(n):
        for i in range(0, len(contexts), n):
            yield contexts[i : i + n], idxs[i : i + n]

    assert track in {"in", "out", "both"}
    both = track == "both"
    tin, tout = (
        (track == "in" or both),
        (track == "out" or both),
    )
    module_name = module_template.format(layer)
    to_return = {"in": [], "out": []}

    def _process(cur_repr, batch_idxs, key):
        nonlocal to_return
        if len(cur_repr) == 0:
            return
        # Some modules return tuples; the hidden states come first
        cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
        for i, idx_list in enumerate(batch_idxs):
            to_return[key].append(cur_repr[i][idx_list].mean(0))

    for batch_contexts, batch_idxs in _batch(n=128):
        contexts_tok = tok(batch_contexts, padding=True, return_tensors="pt").to(
            next(model.parameters()).device
        )

        with torch.no_grad():
            with nethook.Trace(
                module=model,
                layer=module_name,
                retain_input=tin,
                retain_output=tout,
            ) as tr:
                model(**contexts_tok)

        if tin:
            _process(tr.input, batch_idxs, "in")
        if tout:
            _process(tr.output, batch_idxs, "out")

    to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}

    if len(to_return) == 1:
        # Only one side was tracked; pad the other with zeros on the same
        # device as the real tensor (avoids hard-coding "cuda")
        single_key = list(to_return.keys())[0]
        dummy_tensor = torch.zeros_like(to_return[single_key])
        return (
            (to_return["in"], dummy_tensor)
            if single_key == "in"
            else (dummy_tensor, to_return["out"])
        )
    # track == "both": return the (input, output) pair
    return to_return["in"], to_return["out"]
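
# Example usage (a minimal sketch, reusing the GPT-2-style module path from
# above): with track="both", the call returns an (input, output) pair, each
# of shape [len(contexts), hidden_dim]:
#
#   in_reprs, out_reprs = get_reprs_at_idxs(
#       model, tok,
#       contexts=["LeBron James plays basketball"],
#       idxs=[[2]],
#       layer=5,
#       module_template="transformer.h.{}.mlp",
#       track="both",
#   )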