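# app_hierarchical.py
"""Streamlit demo that computes hierarchical, LIME-style attributions for the
completions of autoregressive LLMs (Falcon instruct models via Hugging Face,
or OpenAI completion engines).

Run with the Streamlit CLI, e.g. `streamlit run app_hierarchical.py`
(assumes the lime_tools and model_lib packages from this repository are on
the Python path).
"""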
import streamlit as st
import os
import torch
import pandas as pd
import transformers
from accelerate import Accelerator
from functools import partial
from lime_tools.explainers import (
LocalExplanationHierarchical,
LocalExplanationSentenceEmbedder,
LocalExplanationLikelihood
)
from lime_tools.text_utils import (
split_into_sentences,
extract_non_overlapping_ngrams
)
from model_lib.openai_tooling import (
CompletionsOpenAI
)
from model_lib.hf_tooling import HF_LM
# Needed by load_sentence_embedder below; assumes the InstructorEmbedding
# package (which provides the hkunlp/instructor-* models) is installed.
from InstructorEmbedding import INSTRUCTOR
from model_lib.streamlit_frontend import css, explanation_text, start_frontend, ProgressBar
start_frontend()
dataset_folder = "./data/"
default_instructions = """The following is a conversation between a highly knowledgeable and intelligent AI assistant,
called Falcon, and a human user, called User. In the following interactions,
User and Falcon will converse in natural language, and Falcon will answer User's questions.
Falcon was built to be respectful, polite and inclusive.
Falcon was built by the Technology Innovation Institute in Abu Dhabi.
Falcon will never decline to answer a question,
and always attempts to give an answer that User would be satisfied with.
It knows a lot, and always tells the truth. The conversation begins.
"""
dummy_text = """
Following the death of Pope Pius XII on 9 October 1958,
Roncalli watched the live funeral on his last full day in Venice on 11 October.
His journal was specifically concerned with the funeral and the abused state of the late pontiff's corpse.
Roncalli left Venice for the conclave in Rome well aware that he was papabile,[b] and after eleven ballots,
was elected to succeed the late Pius XII, so it came as no surprise to him,
though he had arrived at the Vatican with a return train ticket to Venice.
[citation needed] Many had considered Giovanni Battista Montini, the Archbishop of Milan,
a possible candidate, but, although he was the archbishop of one of the most ancient and prominent sees in Italy,
he had not yet been made a cardinal. Though his absence from the 1958 conclave did not make him ineligible –
under Canon Law any Catholic male who is capable of receiving priestly ordination
and episcopal consecration may be elected –
the College of Cardinals usually chose the new pontiff from among the Cardinals who head archdioceses or departments
of the Roman Curia that attend the papal conclave. At the time, as opposed to contemporary practice,
the participating Cardinals did not have to be below age 80 to vote, there were few Eastern-rite Cardinals,
and no Cardinals who were just priests at the time of their elevation.
Roncalli was summoned to the final ballot of the conclave at 4:00 pm.
He was elected pope at 4:30 pm with a total of 38 votes. After the long pontificate of Pope Pius XII,
the cardinals chose a man who – it was presumed because of his advanced age –
would be a short-term or "stop-gap" pope. They wished to choose a candidate who would
do little during the new pontificate.
Upon his election, Cardinal Eugene Tisserant asked him the ritual questions of whether he would accept and if so,
what name he would take for himself. Roncalli gave the first of his many
surprises when he chose "John" as his regnal name.
Roncalli's exact words were "I will be called John".
This was the first time in over 500 years that this name had been chosen;
previous popes had avoided its use since the time of the Antipope John XXIII
during the Western Schism several centuries before.
Far from being a mere "stopgap" pope, to great excitement,
John XXIII called for an ecumenical council fewer than ninety years after the First Vatican Council
(Vatican I's predecessor, the Council of Trent, had been held in the 16th century).
This decision was announced on 29 January 1959 at the Basilica of Saint Paul Outside the Walls.
Cardinal Giovanni Battista Montini, who later became Pope Paul VI, remarked to Giulio Bevilacqua that
"this holy old boy doesn't realise what a hornet's nest he's stirring up".
From the Second Vatican Council came changes that reshaped the face of Catholicism:
a comprehensively revised liturgy, a stronger emphasis on ecumenism, and a new approach to the world.
John XXIII was an advocate for human rights which included the unborn and the elderly.
He wrote about human rights in his Pacem in terris. He wrote, "Man has the right to live.
He has the right to bodily integrity and to the means necessary for the proper development of life
, particularly food, clothing, shelter, medical care, rest, and, finally, the necessary social services.
In consequence, he has the right to be looked after in the event of ill health; disability stemming from his work;
widowhood; old age; enforced unemployment; or whenever through no fault of his
own he is deprived of the means of livelihood."
Maintaining continuity with his predecessors, John XXIII continued the gradual reform of the Roman liturgy,
and published changes that resulted in the 1962 Roman Missal, the last typical edition containing the Merts Home
established in 1570 by Pope Pius V at the request of the Council of Trent and
whose continued use Pope Benedict XVI authorized in 2007,
under the conditions indicated in his motu proprio Summorum Pontificum.
In response to the directives of the Second Vatican Council,
later editions of the Roman Missal present the 1970 form of the Roman Rite.
Please answer the below question based only on the above passage.
Question: What did Pope Pius V establish in 1570?"""
dummy_text = dummy_text.strip()
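# Cached loaders: Streamlit re-runs this script on every interaction, so the
# (potentially large) language model and the sentence embedder are cached as
# resources and only loaded once per session.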
@st.cache_resource(max_entries=1)
def load_model(model_name: str, device="cuda"):
    print("Loading model...")
    if "openai" in model_name:
        model_wrapped = CompletionsOpenAI(
            engine=model_name.replace(
                "openai-", ""), format_fn=None)
    else:
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto")
        model_wrapped = HF_LM(model, tokenizer, device=device, format_fn=None)
    return model_wrapped
@st.cache_resource(max_entries=1)
def load_sentence_embedder():
    print("Loading sentence embedder...")
    return INSTRUCTOR('hkunlp/instructor-xl')
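# Falcon-instruct style chat template: the system instructions are followed by
# a single "User: <message>" turn and an open "Falcon:" turn for the model to
# complete, e.g.
#
#   <instructions>
#   User: What did Pope Pius V establish in 1570?
#   Falcon: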
def format_chat_prompt(
        message: str,
        instructions=default_instructions,
        bot_name="Falcon",
        user_name="User") -> str:
    instructions = instructions.strip(" ").strip("\n")
    prompt = instructions
    prompt = f"{prompt}\n{user_name}: {message}\n{bot_name}:"
    return prompt
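# Discover the CSV datasets shipped with the app and keep the active dataset
# (and the row index being browsed) in st.session_state across reruns.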
datasets = [f for f in os.listdir(dataset_folder) if f.endswith('.csv')]
if "dataset_name" not in st.session_state:
st.session_state["dataset_name"] = datasets[0]
data = pd.read_csv(f'{dataset_folder}/{datasets[0]}', header=0)
st.session_state["dataset"] = data
with st.sidebar:
    with st.form(key='my_form'):
        selected_dataset = st.selectbox('Select a dataset', datasets, index=0)
        if st.form_submit_button(label='Load dataset'):
            # Load the selected dataset
            data = pd.read_csv(
                f'{dataset_folder}/{selected_dataset}', header=0)
            st.session_state["index"] = 0
            st.session_state["dataset_name"] = selected_dataset
            st.session_state["dataset"] = data
data = st.session_state["dataset"]
accelerator = Accelerator()
model_pretty = {
"Falcon-7B Instruct": "tiiuae/falcon-7b-instruct",
"Falcon-40B Instruct": "tiiuae/falcon-40b-instruct",
"openai-Davinci002": "openai-mert-DaVinci002",
"openai-Davinci003": "openai-Davinci003",
}
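# Map human-readable display names to model identifiers (Hugging Face hub ids
# or OpenAI engine names, prefixed with "openai-").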
# Render the CSS style
st.markdown(css, unsafe_allow_html=True)
with st.sidebar:
    model_name = st.radio(
        "Choose a model to explain",
        options=list(model_pretty.keys()),
        index=0
    )
    step_explainer_name = st.radio(
        "Choose an explainer for each step in the hierarchy", options=[
            "Likelihood", "Sentence Embedder"], index=0)
    if step_explainer_name == "Likelihood":
        step_explainer_class = LocalExplanationLikelihood
        sentence_embedder = None
    else:
        sentence_embedder = load_sentence_embedder()
        step_explainer_class = partial(
            LocalExplanationSentenceEmbedder,
            sentence_embedder=sentence_embedder)
    mode = st.radio("Mode", options=["Dev", "Present"], index=0)
    if mode == "Dev":
        # ngram_size = st.slider("NGram size", min_value=3, max_value=10, value=5, step=1)
        max_parts_to_perturb = st.slider(
            "Maximum parts to perturb jointly",
            min_value=1,
            max_value=5,
            value=2,
            step=1)
        num_perturbations = st.slider(
            "Number of perturbations",
            min_value=1,
            max_value=500,
            value=40,
            step=50)
        max_features = st.slider(
            "Max features in the explanation",
            min_value=2,
            max_value=10,
            value=2,
            step=1)
        max_completion = st.slider(
            "Maximum tokens in the completion",
            min_value=5,
            max_value=250,
            value=50,
            step=5)
    else:
        ngram_size = 5
        max_parts_to_perturb = 2
        num_perturbations = 100
        max_features = 4
        max_completion = 5
st.session_state["model_name"] = model_name
model_wrapped = load_model(model_name=model_pretty[model_name])
st.title('Hierarchical Explanations for Autoregressive LLMs')
# setting = st.radio("Setting", options=["Passage+QA", "Full"], index=1, horizontal=True)
# A couple of hparams to collect
if 'index' not in st.session_state:
    st.session_state['index'] = 0
cols = st.columns(2)
with cols[0]:
    if st.button('Previous Row'):
        st.session_state['index'] -= 1
        # Ensure it doesn't go below 0
        st.session_state['index'] = max(0, st.session_state['index'])
with cols[1]:
    if st.button('Next Row'):
        # Ensure it doesn't exceed the dataframe length
        st.session_state['index'] += 1
        st.session_state['index'] = min(
            st.session_state['index'], len(data) - 1)
# Use the session_state index to display the row
example_row = data.iloc[st.session_state['index']]
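# NOTE: the selected row's prompt is overridden with the hard-coded passage
# above, so the dataset row currently only drives navigation, not the text.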
example_row["prompt"] = dummy_text
default_prompt = example_row["prompt"]
prompt = st.text_area("Please type your prompt", default_prompt, height=300)
prompt_fn = format_chat_prompt
full_prompt = [prompt_fn(t) for t in [prompt]]
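# Sample a completion for the formatted prompt, then re-tokenize it so the
# user can pick which span of output tokens to explain below.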
with torch.no_grad():
    completion = model_wrapped.sample(
        full_prompt, max_new_tokens=max_completion)[0]
if "openai" in model_name:
    all_completion_tokens = model_wrapped.tokenizer.encode(completion)
    all_target_tokens = [
        model_wrapped.tokenizer.decode(
            [t])[0] for t in all_completion_tokens]
else:
    all_completion_tokens = model_wrapped.tokenizer(completion).input_ids
    all_target_tokens = [model_wrapped.tokenizer.decode(
        t) for t in all_completion_tokens]
markdown_prompt_text = full_prompt[0].replace("\n", "<br>")
prompt_markdown = f'Prompt:<br></span>{markdown_prompt_text}<br><span style="color:blue;">'
markdown_output_text = completion
st.markdown((
f'<div class="highlight">'
f'<span style="color:blue;">{model_name}:</span></span>'
f'<em>{markdown_output_text}</em>'
f'</div>'
), unsafe_allow_html=True)
st.write("### Explanations")
st.write("We will try to understand what parts of the input text leads to the model output.")
with st.form(key="explanation"):
selected_start, selected_end = st.select_slider(
"Select which token to start from.",
options=all_target_tokens,
value=(all_target_tokens[0], all_target_tokens[-1])
)
submit_button = st.form_submit_button("Hierarchically explain now mate.")
if submit_button:
    explain_start_idx, explain_end_idx = all_target_tokens.index(
        selected_start), all_target_tokens.index(selected_end)
    explain_end_idx += 1
    completion_len = explain_end_idx - explain_start_idx
    completion_to_explain = model_wrapped.tokenizer.decode(
        all_completion_tokens[explain_start_idx:explain_end_idx])
    if explain_start_idx == 0:
        prompt_to_explain = full_prompt[0]
    else:
        prompt_to_explain = full_prompt[0] + model_wrapped.tokenizer.decode(
            all_completion_tokens[:explain_start_idx])
    print(prompt_to_explain)
    print(completion_to_explain, "completion_to_explain")
    print("Completion len: ", completion_len)
    st.write(f"We will explain: {completion_to_explain}")
    num_sentences = len(split_into_sentences(prompt))
    num_ngrams = len(extract_non_overlapping_ngrams(prompt, 1))
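    # Build a coarse-to-fine perturbation hierarchy over the prompt: start from
    # groups of sentences, roughly halve the group size at each level down to
    # individual sentences, then move on to n-gram partitions. The per-level
    # "max_features" value presumably caps how many parts are refined further
    # down the hierarchy.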
    perturbation_hierarchy = []
    n_total = 8
    n_log = 4
    if num_sentences > 1:
        cur_n_sent = (num_sentences // n_total) + 1
        while True:
            perturbation_hierarchy.append({"partition_fn": "n_sentences", "partition_kwargs": {
                "n": cur_n_sent}, "max_features": n_log})
            cur_n_sent = ((cur_n_sent // 2) + (cur_n_sent % 2))
            if cur_n_sent <= 1:
                break
    perturbation_hierarchy.append(
        {"partition_fn": "sentences", "partition_kwargs": {}, "max_features": n_log})
    if num_ngrams > 20:
        cur_ngrams = 10
        while True:
            perturbation_hierarchy.append({"partition_fn": "ngrams", "partition_kwargs": {
                "n": cur_ngrams}, "max_features": 5})
            cur_ngrams = (cur_ngrams // 2) + (cur_ngrams % 2)
            if cur_ngrams <= 9:
                break
    perturbation_hierarchy.append(
        {"partition_fn": "ngrams", "partition_kwargs": {"n": 5}, "max_features": 5})
    # perturbation_hierarchy.append({"partition_fn": "ngrams", "partition_kwargs": {"n": 1}, "max_features": 5})
    print(perturbation_hierarchy)
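    # Instantiate the hierarchical explainer: parts are perturbed by removal,
    # the hierarchy above controls how the prompt is partitioned at each step,
    # and the per-step explainer (likelihood- or sentence-embedder-based) was
    # chosen in the sidebar. ProgressBar renders progress in the Streamlit UI.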
    explainer = LocalExplanationHierarchical(
        perturbation_model="removal",
        perturbation_hierarchy=perturbation_hierarchy,
        progress_bar=ProgressBar,
        step_explainer=step_explainer_class,
    )
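    # Callback invoked by the explainer after each level of the hierarchy:
    # renders the parts of that level highlighted by their attribution scores.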
    def intermediate_log_fn(hierarchy_step, step_idx):
        st.subheader(f"Hierarchical Step-{step_idx}")
        parts, parts_to_replace, attribution_scores = hierarchy_step
        exp_text_step = explanation_text(parts, attribution_scores)
        st.markdown(exp_text_step, unsafe_allow_html=True)
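    # Run the attribution. This is the expensive step: roughly
    # `num_perturbations` perturbed prompts are scored by the model at each
    # level of the hierarchy.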
    importance_cache = explainer.attribution(
        model_wrapped,
        prompt_to_explain,
        completion_to_explain,
        max_parts_to_perturb=max_parts_to_perturb,
        max_features=max_features,
        n_samples=num_perturbations,
        prompt_fn=None,
        intermediate_log_fn=intermediate_log_fn)