Skip to content

Commit 22476bc

Browse files
committed
/discojs/src/models/gpt/ layers tests + causal attention split in functions
1 parent cfedb7f commit 22476bc

File tree

2 files changed

+477
-93
lines changed

2 files changed

+477
-93
lines changed

discojs/src/models/gpt/layers.spec.ts

+203-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import * as tf from '@tensorflow/tfjs';
22
import { expect } from 'chai';
3-
import { GELU, LMEmbedding } from './layers.js';
3+
import { GELU, LMEmbedding, Range, MLP, MLPConfig, CausalSelfAttention, CausalSelfAttentionConfig } from './layers.js';
44

55
describe('GPT Layers', function () {
66
// GELU Layer tests
@@ -20,7 +20,7 @@ describe('GPT Layers', function () {
2020
const outputData: Float32Array = await output.data() as Float32Array;
2121

2222
// expected values based on the GELU tanh approximation
23-
const expected: number[] = [0, 0.8415, -0.1585, 1.955, -0.046];
23+
const expected: number[] = [0, 0.8412, -0.1588, 1.955, -0.045];
2424

2525
for (let i = 0; i < expected.length; i++) {
2626
expect(outputData[i]).to.be.closeTo(expected[i], 0.05);
@@ -96,4 +96,205 @@ describe('GPT Layers', function () {
9696
});
9797

9898
});
99+
100+
// Range Layer tests
101+
describe('Range Layer', function () {
102+
103+
afterEach(() => {
104+
// dispose any created tensors/variables
105+
tf.disposeVariables();
106+
});
107+
108+
it('should output a tensor with shape [1, T] for an input of shape [batch, T]', async function () {
109+
const rangeLayer = new Range();
110+
111+
// dummy input tensor with shape [batch, T]
112+
const dummyInput = tf.zeros([3, 10], 'int32');
113+
114+
const output = rangeLayer.apply(dummyInput) as tf.Tensor;
115+
116+
// We expect the output to have shape [1, T] i.e. [1, 10]
117+
expect(output.shape).to.deep.equal([1, 10]);
118+
119+
// verify the content: the layer should output a range [0, 1, ..., T-1]
120+
const outputData = await output.data();
121+
for (let i = 0; i < 10; i++) {
122+
expect(outputData[i]).to.equal(i);
123+
}
124+
});
125+
});
126+
127+
// MLP Layer tests
128+
describe('MLP Layer', function () {
129+
130+
it('should produce deterministic outputs with the same random seed', async function () {
131+
// an MLP config with a fixed seed
132+
const config: MLPConfig = {
133+
name: 'testMLP',
134+
contextLength: 10,
135+
residDrop: 0, // no dropout for deterministic behavior
136+
nLayer: 2,
137+
seed: 42,
138+
nEmbd: 16,
139+
nHead: 4
140+
};
141+
142+
// two separate MLP model instances using the same config
143+
const model1 = MLP(config);
144+
const model2 = MLP(config);
145+
146+
const input = tf.ones([1, config.contextLength, config.nEmbd]);
147+
148+
// get predictions from both models
149+
const output1 = model1.predict(input) as tf.Tensor;
150+
const output2 = model2.predict(input) as tf.Tensor;
151+
152+
const arr1 = await output1.data();
153+
const arr2 = await output2.data();
154+
155+
// check lengths are equal
156+
expect(arr1.length).to.equal(arr2.length);
157+
158+
// check that the models produce the same output
159+
expect(arr1).to.deep.equal(arr2);
160+
161+
});
162+
});
163+
164+
// CausalSelfAttention Layer tests
165+
describe('CausalSelfAttention Helper Methods', function () {
166+
167+
const config: CausalSelfAttentionConfig = {
168+
name: 'testCSA',
169+
contextLength: 5,
170+
nHead: 2,
171+
nEmbd: 8, // divisible by nHead, so head size = 4
172+
dropout: 0.0, // no dropout for deterministic tests
173+
nLayer: 2,
174+
seed: 42
175+
};
176+
177+
let csa: CausalSelfAttention;
178+
179+
// new instance of CausalSelfAttention before each test
180+
beforeEach(() => {
181+
csa = new CausalSelfAttention(config);
182+
// dummy input has shape [batch, T, nEmbd] = [1, contextLength, nEmbd].
183+
const dummyInput = tf.zeros([1, config.contextLength, config.nEmbd], 'float32');
184+
csa.apply(dummyInput);
185+
});
186+
187+
afterEach(() => {
188+
tf.disposeVariables();
189+
});
190+
191+
// describe('_dense', function () {
192+
// it('should compute x * kernel + bias correctly using addWeight', async function () {
193+
// const x = tf.tensor2d([[1, 2]], [1, 2]);
194+
// const kernel = csa.addWeight(
195+
// 'dense_test_kernel',
196+
// [2, 2],
197+
// 'float32',
198+
// tf.initializers.constant({ value: [[1, 0], [0, 1]] })
199+
// ) as tf.layers.LayerVariable;
200+
// const bias = csa.addWeight(
201+
// 'dense_test_bias',
202+
// [2],
203+
// 'float32',
204+
// tf.initializers.constant({ value: [0.5, -0.5] })
205+
// ) as tf.layers.LayerVariable;
206+
207+
// const output = csa._dense(x, kernel, bias);
208+
// const outData = await output.data();
209+
// // Expected calculation:
210+
// // [1,2] dot [[1,0],[0,1]] = [1,2] and then add bias [0.5, -0.5] gives [1.5, 1.5]
211+
// expect(Array.from(outData)).to.deep.equal([1.5, 1.5]);
212+
// });
213+
// });
214+
215+
describe('_splitHeads', function () {
216+
it('should reshape and transpose the input correctly', function () {
217+
const B = 2;
218+
const T = 6;
219+
const totalChannels = config.nEmbd; // 8 channels
220+
// input tensor with shape [B, T, totalChannels]
221+
const input = tf.tensor3d(new Array(B * T * totalChannels).fill(1), [B, T, totalChannels]);
222+
const output = csa._splitHeads(input, B, T, config.nHead);
223+
// expected shape: [B, nHead, T, totalChannels/nHead] = [2, 2, 6, 4]
224+
expect(output.shape).to.deep.equal([B, config.nHead, T, totalChannels / config.nHead]);
225+
});
226+
});
227+
228+
describe('_applyCausalMask', function () {
229+
it('should produce a causal mask that sets upper-triangular positions to -1e9', async function () {
230+
const T = config.contextLength;
231+
// dummy attention logits tensor with shape [1, 1, T, T] filled with zeros
232+
const att = tf.zeros([1, 1, T, T], 'float32');
233+
const masked = csa._applyCausalMask(att, T);
234+
const data = await masked.data();
235+
// for each position (i,j): if j > i expect -1e9 else 0
236+
const expected: number[] = [];
237+
for (let i = 0; i < T; i++) {
238+
for (let j = 0; j < T; j++) {
239+
expected.push(j > i ? -1e9 : 0);
240+
}
241+
}
242+
expect(Array.from(data)).to.deep.equal(expected);
243+
});
244+
});
245+
246+
describe('_computeAttention', function () {
247+
it('should output attention weights that sum to 1 over the last dimension', async function () {
248+
const B = 1;
249+
const nHead = config.nHead;
250+
const T = config.contextLength;
251+
const headSize = config.nEmbd / config.nHead;
252+
const q = tf.randomUniform([B, nHead, T, headSize]);
253+
const k = tf.randomUniform([B, nHead, T, headSize]);
254+
const att = csa._computeAttention(q, k, false, T);
255+
// expected shape: [B, nHead, T, T]
256+
expect(att.shape).to.deep.equal([B, nHead, T, T]);
257+
// check that each row of the attention weights (softmax output, last dimension) sums to approximately 1
258+
const attData = await att.data();
259+
const attArray = Array.from(attData);
260+
for (let b = 0; b < B; b++) {
261+
for (let h = 0; h < nHead; h++) {
262+
for (let i = 0; i < T; i++) {
263+
// calculate the starting index for the i-th row in the flattened tensor
264+
const rowStart = b * nHead * T * T + h * T * T + i * T;
265+
const row = attArray.slice(rowStart, rowStart + T);
266+
const rowSum = row.reduce((sum, val) => sum + val, 0);
267+
expect(rowSum).to.be.closeTo(1, 1e-3);
268+
}
269+
}
270+
}
271+
});
272+
});
273+
274+
// describe('_projectOutput', function () {
275+
// it('should project the input correctly using dense operation with addWeight', async function () {
276+
// const x = tf.tensor2d([[1, 2, 3]], [1, 3]);
277+
// const projKernel = csa.addWeight(
278+
// 'project_test_kernel',
279+
// [3, 2],
280+
// 'float32',
281+
// tf.initializers.constant({ value: [[1, 0], [0, 1], [1, -1]] })
282+
// ) as tf.layers.LayerVariable;
283+
// const projBias = csa.addWeight(
284+
// 'project_test_bias',
285+
// [2],
286+
// 'float32',
287+
// tf.initializers.constant({ value: [0.5, 0.5] })
288+
// ) as tf.layers.LayerVariable;
289+
290+
// const output = csa._projectOutput(x, projKernel, projBias);
291+
// const data = await output.data();
292+
// // Calculation:
293+
// // [1,2,3] dot kernel = [1*1+2*0+3*1, 1*0+2*1+3*(-1)] = [4, -1]
294+
// // Then add bias [0.5, 0.5] = [4.5, -0.5]
295+
// expect(Array.from(data)).to.deep.equal([4.5, -0.5]);
296+
// });
297+
// });
298+
});
299+
99300
});

0 commit comments

Comments (0)