
Commit cfedb7f

test GPT layers

1 parent 42c9d3f commit cfedb7f
2 files changed: +101 -2 lines

discojs/src/models/gpt/layers.spec.ts (+99)

@@ -0,0 +1,99 @@
import * as tf from '@tensorflow/tfjs';
import { expect } from 'chai';
import { GELU, LMEmbedding } from './layers.js';

describe('GPT Layers', function () {
  // GELU Layer tests
  describe('GELU Layer', function () {
    afterEach(() => {
      // Dispose of variables to avoid name collisions in subsequent tests.
      tf.disposeVariables();
    });

    it('should compute GELU activation correctly for known inputs', async function () {
      const geluLayer = new GELU();

      const input: tf.Tensor1D = tf.tensor1d([0, 1, -1, 2, -2]);

      const output = geluLayer.apply(input) as tf.Tensor;
      const outputData: Float32Array = await output.data() as Float32Array;

      // Expected values based on the GELU tanh approximation.
      const expected: number[] = [0, 0.8415, -0.1585, 1.955, -0.046];

      for (let i = 0; i < expected.length; i++) {
        expect(outputData[i]).to.be.closeTo(expected[i], 0.05);
      }
    });
  });

  // LMEmbedding Layer tests
  describe('LMEmbedding Layer', function () {
    it('should return token embeddings with shape [batch_size, sequence_length, nEmbd] for 2D input', function () {
      const vocabSize = 100;
      const nEmbd = 16;
      const seed = 42;

      const lmEmbedding = new LMEmbedding(vocabSize, nEmbd, seed);

      // Dummy 2D input representing token indices: shape [batch_size, sequence_length].
      const tokenIndices = tf.tensor2d([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], [2, 5], 'int32');

      const output = lmEmbedding.apply(tokenIndices) as tf.Tensor;

      // Expected output shape for 2D input: [2, 5, nEmbd].
      expect(output.shape).to.deep.equal([2, 5, nEmbd]);
    });

    it('should return token logits with shape [batch_size, sequence_length, vocabSize] for 3D input', function () {
      const vocabSize = 100;
      const nEmbd = 16;
      const seed = 42;

      const lmEmbedding = new LMEmbedding(vocabSize, nEmbd, seed);

      // Dummy 3D input representing a batch of embeddings: shape [batch_size, sequence_length, nEmbd].
      const embeddingsInput = tf.randomUniform([2, 5, nEmbd]);

      const output = lmEmbedding.apply(embeddingsInput) as tf.Tensor;

      // Expected output shape for 3D input: [2, 5, vocabSize].
      expect(output.shape).to.deep.equal([2, 5, vocabSize]);
    });

    it('should throw an error for unexpected input shape', function () {
      const vocabSize = 100;
      const nEmbd = 16;
      const seed = 42;

      const lmEmbedding = new LMEmbedding(vocabSize, nEmbd, seed);

      // Invalid input tensor with 1D shape.
      const invalidInput = tf.tensor1d([1, 2, 3], 'int32');

      expect(() => lmEmbedding.apply(invalidInput)).to.throw('unexpected input shape');
    });

    it('should throw an error if input is an array with more than one tensor', function () {
      const vocabSize = 100;
      const nEmbd = 16;
      const seed = 42;
      const lmEmbedding = new LMEmbedding(vocabSize, nEmbd, seed);
      const input1 = tf.tensor2d([[1, 2, 3]], [1, 3], 'int32');
      const input2 = tf.tensor2d([[4, 5, 6]], [1, 3], 'int32');
      expect(() => lmEmbedding.apply([input1, input2])).to.throw('expected exactly one tensor');
    });

    it('should compute correct output shape for 2D input using computeOutputShape', function () {
      const vocabSize = 100;
      const nEmbd = 16;
      const seed = 42;
      const lmEmbedding = new LMEmbedding(vocabSize, nEmbd, seed);
      const outputShape = lmEmbedding.computeOutputShape([null, null]);
      expect(outputShape).to.deep.equal([null, null, nEmbd]);
    });
  });
});
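For reference, the expected values in the GELU test follow from the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal standalone sketch that reproduces them (geluTanh is a hypothetical helper for illustration, not part of the codebase):

// Hypothetical helper illustrating the GELU tanh approximation; not part of layers.ts.
function geluTanh(x: number): number {
  return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x ** 3)));
}

// [0, 1, -1, 2, -2] -> approximately [0, 0.8412, -0.1588, 1.9546, -0.0454],
// all within the 0.05 tolerance used by the test above.
console.log([0, 1, -1, 2, -2].map(geluTanh));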

discojs/src/models/gpt/layers.ts (+2 -2)

@@ -228,7 +228,7 @@ tf.serialization.registerClass(CausalSelfAttention)
  *
  * https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
  */
-class GELU extends tf.layers.Layer {
+export class GELU extends tf.layers.Layer {
   static readonly className = 'GELU'

   constructor () {
@@ -368,7 +368,7 @@ function TransformerBlock (conf: BlockConfig): tf.LayersModel {
  * that can be used for both the token embeddings and the language modeling head.
  * In the GPT2 model definition, this layer corresponds to wte and lm_head (which reuses wte)
  */
-class LMEmbedding extends tf.layers.Layer {
+export class LMEmbedding extends tf.layers.Layer {
   static readonly className = 'LMEmbedding'
   embeddings?: tf.LayerVariable
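The dual shape behavior exercised by the new tests follows from the weight tying mentioned in the doc comment above: the same matrix serves as the embedding lookup (wte) for 2D integer inputs and, transposed, as the output projection (lm_head) for 3D float inputs. A minimal sketch of the idea, assuming plain tfjs ops rather than the actual LMEmbedding internals:

import * as tf from '@tensorflow/tfjs';

// Illustrative only: a tied weight matrix analogous to wte.
const vocabSize = 100;
const nEmbd = 16;
const wte = tf.randomNormal([vocabSize, nEmbd]);

// 2D int32 input [batch, seq] -> embedding lookup -> [batch, seq, nEmbd].
const tokens = tf.tensor2d([[1, 2, 3, 4, 5]], [1, 5], 'int32');
const embedded = tf.gather(wte, tokens); // shape [1, 5, 16]

// 3D float input [batch, seq, nEmbd] -> project against wte transposed -> [batch, seq, vocabSize].
const [b, s] = [1, 5];
const logits = tf.matMul(embedded.reshape([b * s, nEmbd]), wte, false, true)
  .reshape([b, s, vocabSize]); // shape [1, 5, 100]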
