-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy path: train_gpt.ts
48 lines (40 loc) · 1.43 KB
/
train_gpt.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import "@tensorflow/tfjs-node"
import { AutoTokenizer } from "@xenova/transformers";
import { models, processing, Dataset } from "@epfml/discojs";
import { List } from "immutable";
async function main(): Promise<void> {
const data = "Lorem ipsum dolor sit amet, consectetur adipis"
const seed = 42
const config: models.GPTConfig = {
modelType: 'gpt-nano',
lr: 0.01,
maxIter: 50,
evaluateEvery:50,
maxEvalBatches: 10,
contextLength: 16,
seed
}
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
const tokenDataset = new Dataset([data])
.map((text: string) => processing.tokenize(tokenizer, text))
.flatten()
.batch(config.contextLength + 1, 1)
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.repeat()
.batch(8);
const model = new models.GPT(config)
for await (const logs of model.train(tokenDataset, undefined)) {
console.log(logs)
}
let tokens = processing.tokenize(tokenizer, "Lorem");
const maxNewTokens = 14
for (let n = 0; n < maxNewTokens; n++) {
const next = (await model.predict(List.of(tokens), { seed })).first();
if (next === undefined) throw new Error("empty prediction");
tokens = tokens.push(next)
}
const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true })
console.log(generation)
}
// You can run this example with "npm run run_gpt" from this folder
main().catch((error) => console.error(error));