# NLP.py — 76 lines (3.85 KB)
# (GitHub page scrape residue — UI text and gutter line numbers — removed.)
from typing import Optional
import torch
import transformers
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
from fastai.text.all import *
import fastai
import re
from datetime import datetime
class DropOutput(Callback):
    """fastai callback: after each forward pass, keep only the logits.

    Hugging Face language models return a tuple whose first element is the
    logits tensor; fastai's loss expects a plain tensor, so we unwrap it here.
    """

    def after_pred(self):
        self.learn.pred = self.pred[0]
class NLP:
    """Holds one fine-tuned KoGPT-2 model per category and generates
    Korean text continuations for a given prompt.
    """

    def __init__(self):
        # One fine-tuned GPT-2 model per supported category, loaded from
        # ./models/<category>.pt.
        # NOTE(review): generate()'s default category "total" is NOT in this
        # list, so calling generate() with no category raises KeyError —
        # confirm whether a "total" model should also be loaded here.
        self.categories = ['it', 'business', 'marketing']
        self.models = dict()
        for category in self.categories:
            self.models[category] = GPT2LMHeadModel.from_pretrained('./models/{}.pt'.format(category))
            print("{} 로딩 완료".format(category))
        # KoGPT-2 tokenizer configured with the special tokens used during
        # fine-tuning (bos/eos share the same '</s>' token).
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2", bos_token='</s>',
                                                                 eos_token='</s>',
                                                                 unk_token='<unk>', pad_token='<pad>',
                                                                 mask_token='<mask>')

    def generate(self, category="total", prompt="제가 가장 중요하게 생각하는 것은", number=1):
        """Sample `number` continuations of `prompt` from the `category` model.

        Parameters
        ----------
        category : str
            Key into self.models; must be one of the loaded categories.
        prompt : str
            Seed text the model continues.
        number : int
            How many independent samples to return.

        Returns
        -------
        list[str]
            The decoded generated sequences (each includes the prompt).

        Raises
        ------
        KeyError
            If `category` has no loaded model (same exception type as the
            previous bare dict lookup, but with an explanatory message).
        """
        if category not in self.models:
            raise KeyError("unknown category {!r}; expected one of {}".format(
                category, sorted(self.models)))
        input_ids = self.tokenizer.encode(prompt)
        # Inference only — disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            preds = self.models[category].generate(torch.tensor([input_ids]),
                                                   max_length=100,
                                                   repetition_penalty=2.0,
                                                   pad_token_id=self.tokenizer.pad_token_id,
                                                   eos_token_id=self.tokenizer.eos_token_id,
                                                   bos_token_id=self.tokenizer.bos_token_id,
                                                   use_cache=True,
                                                   temperature=1.1,
                                                   top_k=50,
                                                   top_p=0.9,
                                                   do_sample=True,
                                                   num_return_sequences=number
                                                   )
        res = []
        for pred in preds:  # the enumerate index was unused — plain iteration
            decoded = self.tokenizer.decode(pred.tolist())
            res.append(decoded)
            print(decoded)
        return res
if __name__ == '__main__':
    nlp = NLP()
    # BUG FIX: 'total' is not among the loaded categories
    # ('it', 'business', 'marketing'), so the previous call
    # nlp.generate('total', ...) always raised KeyError. Use a real category.
    print(*nlp.generate('it', '제가 가장 중요하다고 생각하는 것은', 3), sep='\n')
    # nlp.save()