-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
52 lines (40 loc) · 1.44 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os

from aitextgen import aitextgen
from aitextgen.TokenDataset import TokenDataset
from datasets import load_dataset
from transformers import AutoTokenizer
# Base checkpoint id handed to both aitextgen and AutoTokenizer.from_pretrained.
# For the larger variant, use "EleutherAI/gpt-neo-1.3B" instead.
MODEL = "EleutherAI/gpt-neo-125M"
class Trainer:
    """Fine-tune a language model on one artist's lyrics using aitextgen.

    Lyrics are loaded from the ``huggingartists/<artist_name>`` dataset and
    written to a plain-text corpus under ``lyric_texts/``. Training output is
    written under ``models/``.
    """

    def __init__(self, model, artist_name, to_gpu=True):
        """Load the base model and its tokenizer.

        Args:
            model: Model identifier passed to both ``aitextgen(model=...)`` and
                ``AutoTokenizer.from_pretrained`` (e.g. "EleutherAI/gpt-neo-125M").
            artist_name: Suffix of the ``huggingartists/<artist_name>`` dataset;
                also used to name the corpus file and the output directory.
            to_gpu: Forwarded to ``aitextgen``. Defaults to True, which matches
                the previous hard-coded behavior.
        """
        self.model = model
        self.artist_name = artist_name
        self.ds = None  # set by download_dataset()
        self.ai = aitextgen(model=model, to_gpu=to_gpu)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
        # Causal-LM tokenizers such as GPT-Neo's typically define no pad token,
        # so reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def file_name(self):
        """Return the path of this artist's plain-text lyrics corpus."""
        return f"lyric_texts/{self.artist_name}.txt"

    def normalized_model_name(self):
        """Return the model id made path-friendly: '/' becomes '--', lowercased."""
        return self.model.replace("/", "--").lower()

    def model_dir(self):
        """Return the output directory passed to ``aitextgen.train``."""
        return f"models/{self.artist_name}-{self.normalized_model_name()}"

    @staticmethod
    def _build_corpus(texts, bos_token, eos_token):
        """Wrap each text in BOS/EOS markers and join them into one string.

        Each entry ends with a newline, and entries are joined with another
        newline, so consecutive texts are separated by a blank line. A missing
        (None) token is treated as an empty string; otherwise the literal text
        "None" would be written into the training data.
        """
        bos = bos_token or ""
        eos = eos_token or ""
        return "\n".join(f"{bos}{text}{eos}\n" for text in texts)

    def download_dataset(self):
        """Download the artist's lyrics and write them to ``file_name()``.

        The loaded dataset is kept on ``self.ds``. The ``text`` column of its
        ``train`` split is written as a UTF-8 corpus. The parent directory is
        created if it does not exist yet.
        """
        self.ds = load_dataset(f"huggingartists/{self.artist_name}")
        content = self._build_corpus(
            self.ds["train"]["text"],
            self.tokenizer.bos_token,
            self.tokenizer.eos_token,
        )
        path = self.file_name()
        directory = os.path.dirname(path)
        if directory:
            # Create lyric_texts/ on first run instead of failing with
            # FileNotFoundError.
            os.makedirs(directory, exist_ok=True)
        # Use an explicit encoding: lyrics may contain non-ASCII characters
        # that the platform default codec (e.g. cp1252) cannot encode.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    def train(self, block_size=64, batch_size=1, num_steps=5000,
              save_every=1000, generate_every=500):
        """Fine-tune ``self.ai`` on the corpus at ``file_name()``.

        ``download_dataset()`` must have written the corpus first. Every
        keyword argument defaults to the value previously hard-coded here.

        Args:
            block_size: Token block length passed to ``TokenDataset``.
            batch_size: Forwarded to ``aitextgen.train``.
            num_steps: Forwarded to ``aitextgen.train``.
            save_every: Forwarded to ``aitextgen.train``.
            generate_every: Forwarded to ``aitextgen.train``.
        """
        data = TokenDataset(self.file_name(), tokenizer=self.tokenizer,
                            block_size=block_size)
        self.ai.train(
            data,
            output_dir=self.model_dir(),
            batch_size=batch_size,
            num_steps=num_steps,
            save_every=save_every,
            generate_every=generate_every,
        )
if __name__ == '__main__':
    # Fine-tune one model per artist, one after another.
    for artist_slug in ('metallica', 'eminem', 'katy-perry'):
        run = Trainer(MODEL, artist_slug)
        run.download_dataset()
        run.train()
        # Drop the reference so this run's model can be garbage-collected
        # before the next Trainer loads another copy.
        del run