Skip to content

Commit bee8642

Browse files
committed
fix: test
1 parent a015628 commit bee8642

File tree

2 files changed

+33
-28
lines changed

2 files changed

+33
-28
lines changed

src/nn/nn-vulkan-test.cpp

+17-17
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ void testRmsNorm_F32_F32_F32() {
7272
float *xPipe = (float *)execution->pipes[0];
7373
for (NnUint b = 0; b < batchSize; b++) {
7474
float *xBatchPipe = &xPipe[b * RMS_NORM_DIM];
75-
for (NnUint i = 0; i < RMS_NORM_DIM; i++)
76-
xBatchPipe[i] = (float)(RMS_NORM_DIM - i) / (float)(RMS_NORM_DIM / 2);
75+
for (NnUint i = 0; i < RMS_NORM_DIM; i++) {
76+
float u = (float)(RMS_NORM_DIM - i + b) / (float)(RMS_NORM_DIM / 2);
77+
xBatchPipe[i] = u;
78+
}
7779
}
7880

7981
// act
@@ -83,22 +85,20 @@ void testRmsNorm_F32_F32_F32() {
8385
float invRmsBuffer[N_BATCHES];
8486
device->data->buffers[0].get()->read((NnByte *)invRmsBuffer);
8587

88+
float expectedS[N_BATCHES];
89+
expectedS[0] = 0.863493f;
90+
expectedS[1] = 0.858468f;
91+
8692
for (NnUint b = 0; b < batchSize; b++) {
8793
float *xBatchPipe = &xPipe[b * RMS_NORM_DIM];
8894

89-
float t = 0.000001f;
90-
assertFloat(b, invRmsBuffer[b], 0.863493f, t);
91-
assertFloat(0, xBatchPipe[0], 0.001687f, t);
92-
assertFloat(1, xBatchPipe[1], 0.008400f, t);
93-
assertFloat(2, xBatchPipe[2], 0.015060f, t);
94-
assertFloat(35, xBatchPipe[35], 0.205286f, t);
95-
assertFloat(36, xBatchPipe[36], 0.210155f, t);
96-
assertFloat(119, xBatchPipe[119], 0.430514f, t);
97-
assertFloat(123, xBatchPipe[123], 0.431964f, t);
98-
assertFloat(234, xBatchPipe[234], 0.135804f, t);
99-
assertFloat(242, xBatchPipe[242], 0.089372f, t);
100-
assertFloat(249, xBatchPipe[249], 0.045977f, t);
101-
assertFloat(255, xBatchPipe[255], 0.006726f, t);
95+
const float t = 0.000001f;
96+
const float s = expectedS[b];
97+
assertFloat(b, invRmsBuffer[b], s, t);
98+
for (NnUint i = 0; i < RMS_NORM_DIM; i++) {
99+
float u = (float)(RMS_NORM_DIM - i + b) / (float)(RMS_NORM_DIM / 2);
100+
assertFloat(b * RMS_NORM_DIM + i, xBatchPipe[i], (u * s) * normWeight[i], t);
101+
}
102102
}
103103
printOk("testRmsNorm_F32_F32_F32");
104104
});
@@ -165,7 +165,7 @@ void testMul_F32_F32() {
165165
float sBuffer[MUL_DIM * N_BATCHES];
166166
for (NnUint i = 0; i < MUL_DIM * N_BATCHES; i++) {
167167
xPipe[i] = (float)i;
168-
sBuffer[i] = cosf((float)i);
168+
sBuffer[i] = (i % 8) / 10.0f;
169169
}
170170

171171
device->data->buffers[0].get()->write((NnByte *)sBuffer);
@@ -175,7 +175,7 @@ void testMul_F32_F32() {
175175

176176
// assert
177177
for (NnUint i = 0; i < MUL_DIM * N_BATCHES; i++)
178-
assertFloat(i, xPipe[i], i * cosf((float)i), 0.00001f);
178+
assertFloat(i, xPipe[i], i * ((i % 8) / 10.0f), 0.000001f);
179179
printOk("testMul_F32_F32");
180180
});
181181
}

src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp

+16-11
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,30 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
1616
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1717
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
1818
layout(binding = 4) readonly uniform configBuffer {
19-
uint invRmsBufferIndex;
19+
uint invRmsBufferIndex; // not used
2020
};
2121
layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; };
2222

23+
shared BatchInfo sharedInfo;
24+
shared float s;
25+
2326
void main() {
2427
const uint threadIndex = uint(gl_LocalInvocationID.x);
2528
const uint batchIndex = uint(gl_GlobalInvocationID.y);
2629

27-
const uint inputSizeX = infos[batchIndex].inputSizeX;
28-
const uint offset = infos[batchIndex].inputOffset;
29-
const uint slice = inputSizeX / N_THREADS;
30-
const uint rest = inputSizeX % N_THREADS;
31-
const uint start = threadIndex * slice + (threadIndex < rest ? threadIndex : rest);
32-
const uint end = start + slice + (threadIndex < rest ? 1 : 0);
30+
if (threadIndex == 0) {
31+
sharedInfo = infos[batchIndex];
32+
s = invRms[batchIndex];
33+
}
34+
35+
barrier();
36+
memoryBarrierShared();
3337

34-
const float s = invRms[batchIndex];
38+
const uint inputSizeX = sharedInfo.inputSizeX;
39+
const uint xOffset = sharedInfo.inputOffset;
40+
const uint yOffset = sharedInfo.outputOffset;
3541

36-
for (uint i = start; i < end; i++) {
37-
uint j = offset + i;
38-
y[j] = (x[j] * s) * weight[i];
42+
for (uint i = threadIndex; i < inputSizeX; i += N_THREADS) {
43+
y[yOffset + i] = (x[xOffset + i] * s) * weight[i];
3944
}
4045
}

0 commit comments

Comments
 (0)