
Commit da41445

testCast_F32_Q80.
1 parent 2d0bbaf commit da41445

3 files changed: +149 -22 lines changed

src/nn/nn-vulkan-test.cpp (+48 -13)
@@ -1,5 +1,6 @@
 #include <cstdio>
 #include "nn-config-builder.hpp"
+#include "nn-quants.hpp"
 #include "nn-vulkan.hpp"
 
 #define N_BATCHES 4
@@ -309,11 +310,11 @@ void testShift_F32_F32() {
 }
 
 void testCast_F32_F32() {
-    #define CAST_DIM 64
+    #define CAST_F32_DIM 64
     execute(
         [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
-            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, CAST_DIM));
-            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, CAST_DIM));
+            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, CAST_F32_DIM));
+            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, CAST_F32_DIM));
             segmentBuilder->addOp(
                 OP_CAST, "cast", 0,
                 pointerBatchConfig(SRC_PIPE, xPipeIndex),
@@ -327,25 +328,56 @@ void testCast_F32_F32() {
             float *xPipe = (float *)execution->pipes[0];
             float *yPipe = (float *)execution->pipes[1];
 
-            for (NnUint b = 0; b < N_BATCHES; b++) {
-                for (NnUint i = 0; i < CAST_DIM; i++)
-                    xPipe[b * CAST_DIM + i] = (float)b;
-            }
+            for (NnUint i = 0; i < N_BATCHES * CAST_F32_DIM; i++)
+                xPipe[i] = (float)(i + 1);
 
             // act
             executor->forward();
 
             // assert
-            for (NnUint b = 0; b < N_BATCHES; b++) {
-                for (NnUint i = 0; i < CAST_DIM; i++) {
-                    NnUint j = b * CAST_DIM + i;
-                    assertFloat(j, yPipe[j], (float)b, 0.00001f);
-                }
-            }
+            for (NnUint i = 0; i < N_BATCHES * CAST_F32_DIM; i++)
+                assertFloat(i, yPipe[i], (float)(i + 1), 0.00001f);
             printOk("testCast_F32_F32");
         });
 }
 
+void testCast_F32_Q80() {
+    #define CAST_Q80_DIM 256
+    execute(
+        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
+            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, CAST_Q80_DIM));
+            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_Q80, N_BATCHES, CAST_Q80_DIM));
+            segmentBuilder->addOp(
+                OP_CAST, "cast", 0,
+                pointerBatchConfig(SRC_PIPE, xPipeIndex),
+                pointerBatchConfig(SRC_PIPE, yPipeIndex),
+                size0(),
+                NnCastOpCodeConfig{});
+        },
+        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
+            // arrange
+            execution->setBatchSize(N_BATCHES);
+            float *xPipe = (float *)execution->pipes[0];
+            NnBlockQ80 *yPipe = (NnBlockQ80 *)execution->pipes[1];
+
+            for (NnUint i = 0; i < N_BATCHES * CAST_Q80_DIM; i++)
+                xPipe[i] = (float)(i + 1);
+
+            // act
+            executor->forward();
+
+            float yF32[CAST_Q80_DIM * N_BATCHES];
+            dequantizeQ80toF32(yPipe, yF32, CAST_Q80_DIM * N_BATCHES, 1, 0);
+
+            for (NnUint i = 0; i < N_BATCHES * CAST_Q80_DIM; i++) {
+                const float expectedV = (float)(i + 1);
+                const float change = (yF32[i] - expectedV) / expectedV;
+                assertFloat(i, change, 0.0, 0.009f);
+            }
+            printOk("testCast_F32_Q80");
+        });
+}
+
 void testRope_F32_F32() {
     #define ROPE_DIM 2048
     #define ROPE_KV_DIM 512
@@ -490,13 +522,16 @@ void multiheadAtt_F32_F32() {
 }
 
 int main() {
+    initQuants();
+
     testRmsNorm_F32_F32_F32();
     testSilu_F32_F32();
     testMul_F32_F32();
     testMergeAdd_F32_F32();
     testEmbedding_F32_F32();
     testShift_F32_F32();
     testCast_F32_F32();
+    testCast_F32_Q80();
     testRope_F32_F32();
     matmul_F32_F32_F32();
     multiheadAtt_F32_F32();
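
Note on the new test: an F32-to-Q80 cast is lossy, so testCast_F32_Q80 cannot compare outputs exactly. Instead it dequantizes the Q80 pipe back to F32 and asserts that the relative error of each value stays below 0.9%. A minimal CPU sketch of the Q8_0 round trip being exercised (the block layout and the amax/127 scale rule come from the shader added below; quantizeRefQ80 and dequantizeRefQ80 are illustrative helpers, not the repo's API):

#include <cmath>
#include <cstdint>

#define Q80_BLOCK_SIZE 32

// Illustrative stand-in for NnBlockQ80 (the real struct stores d as float16).
struct BlockQ80Ref {
    float d;
    int8_t qs[Q80_BLOCK_SIZE];
};

static void quantizeRefQ80(const float *x, BlockQ80Ref *b) {
    float amax = 0.0f;
    for (int j = 0; j < Q80_BLOCK_SIZE; j++)
        amax = std::fmax(amax, std::fabs(x[j]));
    b->d = amax / 127.0f;                          // same scale rule as the shader
    const float id = b->d != 0.0f ? 1.0f / b->d : 0.0f;
    for (int j = 0; j < Q80_BLOCK_SIZE; j++)
        b->qs[j] = (int8_t)std::lround(x[j] * id); // round-to-nearest, |q| <= 127
}

static void dequantizeRefQ80(const BlockQ80Ref *b, float *y) {
    for (int j = 0; j < Q80_BLOCK_SIZE; j++)
        y[j] = b->qs[j] * b->d;                    // reverses the scaling step
}

For this input pattern the worst case is the first block (values 1..32): d = 32/127, so x = 1 quantizes to q = 4 and dequantizes to about 1.008, a relative error of roughly 0.8%, which is why the tolerance sits just above that at 0.009.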

src/nn/nn-vulkan.cpp (+23 -9)
@@ -242,23 +242,36 @@ NnUint NnVulkanDeviceData::resolveBufferBatchOffset(NnPointerConfig *config, NnUint batchIndex) {
     assert(batchIndex < netConfig->nBatches);
     if (config->type == PNTR_RAW)
         return 0;
-    NnSize2D bufferSize = resolveBufferSize(config);
+
+    const NnSize2D bufferSize = resolveBufferSize(config);
+    const NnSize blockSize = getBlockSize(bufferSize.floatType);
+    assert(bufferSize.x % blockSize == 0);
+    const NnUint sizeX = bufferSize.x / blockSize;
+
     if (config->type == PNTR_BATCH)
-        return bufferSize.x * batchIndex;
-    if (config->type == PNTR_BATCHED_SLICE)
-        return bufferSize.x * batchIndex + (bufferSize.x / netConfig->nNodes) * nodeConfig->nodeIndex;
+        return sizeX * batchIndex;
+    if (config->type == PNTR_BATCHED_SLICE) {
+        assert(sizeX % netConfig->nNodes == 0);
+        return sizeX * batchIndex + (sizeX / netConfig->nNodes) * nodeConfig->nodeIndex;
+    }
     throw std::runtime_error("Cannot determine buffer offset");
 }
 
 NnUint NnVulkanDeviceData::resolveBufferBatchWidth(NnPointerConfig *config, NnUint batchIndex) {
     assert(batchIndex < netConfig->nBatches);
-    NnSize2D bufferSize = resolveBufferSize(config);
+    const NnSize2D bufferSize = resolveBufferSize(config);
+    const NnSize blockSize = getBlockSize(bufferSize.floatType);
+    assert(bufferSize.x % blockSize == 0);
+    const NnUint sizeX = bufferSize.x / blockSize;
+
     if (config->type == PNTR_RAW)
-        return bufferSize.x;
+        return sizeX;
     if (config->type == PNTR_BATCH)
-        return bufferSize.x;
-    if (config->type == PNTR_BATCHED_SLICE)
-        return bufferSize.x / netConfig->nNodes;
+        return sizeX;
+    if (config->type == PNTR_BATCHED_SLICE) {
+        assert(sizeX % netConfig->nNodes == 0);
+        return sizeX / netConfig->nNodes;
+    }
     throw std::runtime_error("Cannot determine buffer width");
 }
 
@@ -389,6 +402,7 @@ static const char *getShaderFileName(const NnOpCode opCode, const NnOpQuantType quantType) {
     }
     if (opCode == OP_CAST) {
         if (quantType == F32_F32_F32) return "cast-forward-f32-f32.spv";
+        if (quantType == F32_F32_Q80) return "cast-forward-f32-q80.spv";
     }
     if (opCode == OP_SHIFT) {
         if (quantType == F32_F32_F32) return "shift-forward-f32-f32.spv";
New file (+78): the OP_CAST F32-to-Q80 compute shader (compiled to cast-forward-f32-q80.spv)

@@ -0,0 +1,78 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_EXT_shader_explicit_arithmetic_types : enable
+
+#define Q80_BLOCK_SIZE 32
+#define N_THREADS 64
+
+layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
+
+struct BatchInfo {
+    uint inputOffset;
+    uint inputSizeX;
+    uint outputOffset; // number of Q80 blocks
+    uint outputSizeX; // number of Q80 blocks
+};
+
+struct BlockQ80 {
+    float16_t d;
+    int8_t qs[Q80_BLOCK_SIZE];
+};
+
+layout(binding = 0) readonly buffer inputBuffer { float x[]; };
+layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; };
+layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
+
+shared uint sharedYStart;
+shared uint sharedYEnd;
+shared uint sharedXOffset;
+shared uint sharedYOffset;
+
+void main() {
+    const uint threadIndex = gl_LocalInvocationID.x;
+
+    if (threadIndex == 0) {
+        const uint nWorkGroups = gl_NumWorkGroups.z;
+        const uint batchIndex = gl_WorkGroupID.y;
+        const uint workGroupIndex = gl_WorkGroupID.z;
+
+        const BatchInfo info = infos[batchIndex];
+
+        const uint ySlice = info.outputSizeX / nWorkGroups;
+        const uint yRest = info.outputSizeX % nWorkGroups;
+        sharedYStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
+        sharedYEnd = sharedYStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
+        sharedXOffset = info.inputOffset;
+        sharedYOffset = info.outputOffset;
+    }
+
+    barrier();
+    memoryBarrierShared();
+
+    const uint yStart = sharedYStart + threadIndex;
+    const uint yEnd = sharedYEnd;
+    const uint xOffset = sharedXOffset;
+    const uint yOffset = sharedYOffset;
+
+    for (uint i = yStart; i < yEnd; i += N_THREADS) {
+        const uint xiOffset = xOffset + i * Q80_BLOCK_SIZE;
+        const uint yiOffset = yOffset + i;
+
+        float amax = 0.0;
+        for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
+            const float v = abs(x[xiOffset + j]);
+            amax = max(amax, v);
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d != 0.0 ? 1.0 / d : 0.0;
+
+        y[yiOffset].d = float16_t(d);
+
+        for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
+            const float v = x[xiOffset + j];
+            y[yiOffset].qs[j] = int8_t(clamp(round(v * id), -127.0, 127.0));
+        }
+    }
+}
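
For context on the scheduling in the shader above: thread 0 of each work group computes that group's share of the output blocks, spreading the remainder over the first outputSizeX % nWorkGroups groups, and every thread then walks its group's range with stride N_THREADS. A CPU model of the partition, with illustrative sizes (8 blocks over 3 groups; the real counts come from BatchInfo and the dispatch dimensions):

#include <cstdio>

int main() {
    const unsigned nBlocks = 8, nWorkGroups = 3;
    const unsigned ySlice = nBlocks / nWorkGroups; // 2 blocks per group, plus...
    const unsigned yRest = nBlocks % nWorkGroups;  // ...1 extra for the first 2 groups
    for (unsigned g = 0; g < nWorkGroups; g++) {
        const unsigned yStart = g * ySlice + (g < yRest ? g : yRest);
        const unsigned yEnd = yStart + ySlice + (g < yRest ? 1 : 0);
        std::printf("work group %u quantizes blocks [%u, %u)\n", g, yStart, yEnd);
    }
    return 0; // prints the ranges [0, 3), [3, 6), [6, 8)
}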
