-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmodel.ts
105 lines (89 loc) · 3.57 KB
/
model.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import createDebug from "debug";
import * as tf from '@tensorflow/tfjs'
import type { GPTConfig } from './config.js'
import { getModelSizes, DefaultGPTConfig } from './config.js'
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js'
import { GPTArchitecture } from './layers.js'
const debug = createDebug("discojs:models:gpt:model");
/**
 * tfjs does not export LazyIterator and Dataset...
 *
 * Minimal ambient re-declaration of tfjs' internal LazyIterator: an async
 * pull-based iterator yielding IteratorResult<T>. Declared here only so the
 * Dataset type below can be expressed without importing tfjs internals.
 */
declare abstract class LazyIterator<T> {
abstract next (): Promise<IteratorResult<T>>
}
/**
 * Ambient re-declaration of tfjs' internal Dataset type: an async source of
 * elements of T with a known `size`. NOTE(review): presumably mirrors
 * tf.data.Dataset's shape — confirm it stays in sync with the tfjs version
 * in use, since this declaration is not checked against the real class.
 */
export declare abstract class Dataset<T> {
abstract iterator (): Promise<LazyIterator<T>>
size: number
}
/**
 * GPTModel extends tf.LayersModel and overrides tfjs' default training loop.
 *
 * The model graph is either built from GPTArchitecture (based on the resolved
 * config) or wrapped around an existing tf.LayersModel passed by the caller.
 */
export class GPTModel extends tf.LayersModel {
  /** Fully resolved configuration: defaults overridden by user values, plus layer sizes. */
  protected readonly config: Required<GPTConfig>

  /**
   * @param partialConfig user overrides; missing fields fall back to DefaultGPTConfig
   * @param layersModel optional pre-built model to wrap instead of building a fresh GPT graph
   */
  constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel) {
    // Fill missing config parameters with default values
    const mergedConfig = { ...DefaultGPTConfig, ...partialConfig }
    // Add layer sizes depending on which model has been specified
    const completeConfig: Required<GPTConfig> = {
      ...mergedConfig,
      ...getModelSizes(mergedConfig.modelType),
    }
    if (layersModel !== undefined) {
      super({ inputs: layersModel.inputs, outputs: layersModel.outputs, name: layersModel.name })
    } else {
      const { inputs, outputs, name } = GPTArchitecture(completeConfig)
      super({ inputs, outputs, name })
    }
    this.config = completeConfig
  }

  /** The fully resolved GPT configuration this model was built with. */
  get getGPTConfig(): Required<GPTConfig> {
    return this.config
  }

  /**
   * Lazily creates the optimizer on first call; subsequent calls are no-ops.
   * Uses a custom Adam with decoupled weight decay when weightDecay != 0,
   * otherwise plain Adam.
   */
  override compile(): void {
    if (this.optimizer !== undefined) return
    this.optimizer = this.config.weightDecay !== 0
      ? getCustomAdam(this, this.config.lr, this.config.weightDecay)
      : tf.train.adam(this.config.lr)
  }

  /**
   * Runs a single optimization step on one (x, y) batch.
   *
   * NOTE: disposes both input tensors before returning — callers must not
   * reuse `x` or `y` afterwards.
   *
   * @param x input tensor fed to the model
   * @param y target tensor (one-hot, judging by the softmaxCrossEntropy /
   *          categoricalAccuracy usage below)
   * @returns [loss, accuracy] for the batch
   */
  override async trainOnBatch(x: tf.Tensor, y: tf.Tensor): Promise<number | number[]> {
    let weightUpdateTime = performance.now()
    // Awaiting the tensors' data forces pending preprocessing/upload to
    // complete, so it can be timed separately from the weight update.
    let preprocessingTime = performance.now()
    await Promise.all([x.data(), y.data()])
    preprocessingTime = performance.now() - preprocessingTime

    // Definite-assignment assertion: the gradient closure below always assigns
    // logitsTensor before it is read (this replaces the previous
    // @ts-expect-error suppressions for "used before being assigned").
    let logitsTensor!: tf.Tensor
    const lossTensor = tf.tidy(() => {
      // `value` (not `lossTensor`) avoids shadowing the outer const.
      const { grads, value } = this.optimizer.computeGradients(() => {
        const logits = this.apply(x)
        if (Array.isArray(logits))
          throw new Error('model outputs too many tensor')
        if (logits instanceof tf.SymbolicTensor)
          throw new Error('model outputs symbolic tensor')
        // Keep the logits alive past tidy() so accuracy can be computed below.
        logitsTensor = tf.keep(logits)
        return tf.losses.softmaxCrossEntropy(y, logits)
      })
      // Clip gradients to a global norm of 1 before applying them.
      const gradsClipped = clipByGlobalNormObj(grads, 1)
      this.optimizer.applyGradients(gradsClipped)
      return value
    })

    const accTensor = tf.metrics.categoricalAccuracy(y, logitsTensor)
    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
    const accSumTensor = accTensor.sum()
    const accSum = await accSumTensor.array()
    if (typeof accSum !== 'number')
      throw new Error('got multiple accuracy sum')
    tf.dispose([accTensor, accSumTensor, logitsTensor])

    const loss = await lossTensor.array()
    weightUpdateTime = performance.now() - weightUpdateTime
    tf.dispose([x, y, lossTensor])

    const memory = tf.memory().numBytes / 1024 / 1024 / 1024 // GiB
    debug("training metrics: %O", {
      loss,
      memory,
      allocated: tf.memory().numTensors,
      preprocessingTime,
      weightUpdateTime,
    });
    return [loss, accSum / accSize]
  }
}