evaluate_ppl.py
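"""Loss-based membership inference against a GPT-2 model fine-tuned on tweets.

Compares per-batch losses of the base `gpt2` model and the fine-tuned
`twitter_model_target_epoch_4` model on member (split 0) and non-member
(split 1) texts, then sweeps a threshold on the loss differences to report
the attack's TPR at several fixed false-positive rates."""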
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
import torch
import pandas as pd


def get_all_texts(split):
    # Column 6 of the headerless CSV holds the tweet text; cap at 20k examples.
    df = pd.read_csv(f'twitter_split_{split}.csv', header=None)
    texts = df[6].to_list()
    print(texts[:10])
    return texts[:20000]


def get_all_texts_clean(split):
    df = pd.read_csv(f'twitter_split_{split}.csv', header=None)
    texts = df[6].to_list()
    print(texts[:10])
    # str.strip treats its argument as a character set and would also trim
    # ordinary leading/trailing letters; remove the literal EOS token instead.
    return [t.replace('<|endoftext|>', '') for t in texts][:20000]


class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts):
        self.train_texts = texts

    def __len__(self):
        return len(self.train_texts)

    def __getitem__(self, index):
        return self.train_texts[index]


def eval_ppl(texts_eval, model_name):
    print(f"evaluating {model_name}")
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model = model.to('cuda:2')
    test_dataset = TextDataset(texts_eval)
    # shuffle must stay off: per-batch losses from the base and fine-tuned models
    # are zipped together below, so batch order has to match across calls.
    eval_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=16)
    model.eval()
    epoch_loss = 0
    losses = []  # one mean cross-entropy loss per batch
    with torch.no_grad():
        for texts in tqdm(eval_dataloader):
            # Note: no attention mask is passed, so padded EOS positions
            # also contribute to the reported loss.
            text_tokenized = tokenizer(texts, padding=True, truncation=True,
                                       max_length=70, return_tensors='pt').input_ids.to('cuda:2')
            loss = model(text_tokenized, labels=text_tokenized).loss
            epoch_loss += loss.item()
            losses.append(loss.item())
    print("evaluation loss", epoch_loss / len(eval_dataloader))  # mean per-batch loss
    return losses


if __name__ == '__main__':
    # Split 1 holds non-member texts; split 0 holds members seen during fine-tuning.
    texts_eval = get_all_texts_clean(1)
    l1_ = eval_ppl(texts_eval, "gpt2")
    l2_ = eval_ppl(texts_eval, "twitter_model_target_epoch_4")
    non_member_differences = [l2 - l1 for l1, l2 in zip(l1_, l2_)]

    texts_eval = get_all_texts_clean(0)
    l1_ = eval_ppl(texts_eval, "gpt2")
    l2_ = eval_ppl(texts_eval, "twitter_model_target_epoch_4")
    member_differences = [l2 - l1 for l1, l2 in zip(l1_, l2_)]

    print("lengths", len(member_differences))
    print("mean diff members", sum(member_differences) / len(member_differences))
    print("mean diff non members", sum(non_member_differences) / len(non_member_differences))

    # Sweep the decision threshold over every observed loss difference
    # (sorted once, outside the loop) and report the TPR the first time
    # the FPR crosses each target rate (0.01%, 0.1%, 1%, 10%).
    thresholds = sorted(member_differences + non_member_differences)
    prev_fpr = 0
    for threshold in thresholds:
        # Flag a batch as "member" when its loss difference is at or below the threshold.
        tp = sum(1 for diff in member_differences if diff <= threshold)
        fn = len(member_differences) - tp
        fp = sum(1 for diff in non_member_differences if diff <= threshold)
        tn = len(non_member_differences) - fp
        tpr = tp / (tp + fn)
        fpr = fp / (fp + tn)
        for target in (0.0001, 0.001, 0.01, 0.1):
            if prev_fpr < target <= fpr:
                print("tpr", tpr)
                print("fpr", fpr)
        prev_fpr = fpr
        if fpr >= 0.1:
            break