Skip to content

Commit a0f2803

Browse files
committed
testMatmul_Q80_Q40_F32.
1 parent bee26bb commit a0f2803

5 files changed

+165
-21
lines changed

src/nn/nn-cpu-ops.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ static void matmul_Q80_Q40_F32(float *output, const NnBlockQ80 *x, const NnBlock
430430
const int w0 = (wb->qs[k] & 0x0F) - 8;
431431
const int w1 = (wb->qs[k] >> 4) - 8;
432432
const int i1 = xb->qs[k];
433-
const int i2 = xb->qs[k + Q40_BLOCK_SIZE / 2];
433+
const int i2 = xb->qs[k + Q80_BLOCK_SIZE / 2];
434434
sum += (w0 * i1 + w1 * i2) * s;
435435
}
436436
}

src/nn/nn-vulkan-test.cpp

+76-19
Original file line numberDiff line numberDiff line change
@@ -483,18 +483,18 @@ void testRope_F32_F32() {
483483
});
484484
}
485485

486-
void matmul_F32_F32_F32() {
487-
#define MATMUL_N 64
488-
#define MATMUL_D 96
486+
void testMatmul_F32_F32_F32() {
487+
#define MATMUL_F32_N 64
488+
#define MATMUL_F32_D 96
489489
execute(
490490
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
491-
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MATMUL_N));
492-
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, MATMUL_D));
491+
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MATMUL_F32_N));
492+
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, MATMUL_F32_D));
493493
segmentBuilder->addOp(
494494
OP_MATMUL, "matmul", 0,
495495
pointerBatchConfig(SRC_PIPE, xPipeIndex),
496496
pointerBatchConfig(SRC_PIPE, yPipeIndex),
497-
size2D(F_32, MATMUL_N, MATMUL_D),
497+
size2D(F_32, MATMUL_F32_N, MATMUL_F32_D),
498498
NnMatmulOpConfig{});
499499
},
500500
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
@@ -503,32 +503,88 @@ void matmul_F32_F32_F32() {
503503
float *xPipe = (float *)execution->pipes[0];
504504
float *yPipe = (float *)execution->pipes[1];
505505

506-
float weight[MATMUL_N * MATMUL_D];
507-
for (NnUint i = 0; i < N_BATCHES * MATMUL_N; i++)
506+
float weight[MATMUL_F32_N * MATMUL_F32_D];
507+
for (NnUint i = 0; i < N_BATCHES * MATMUL_F32_N; i++)
508508
xPipe[i] = i * 0.01f;
509-
for (NnUint i = 0; i < MATMUL_N * MATMUL_D; i++)
509+
for (NnUint i = 0; i < MATMUL_F32_N * MATMUL_F32_D; i++)
510510
weight[i] = i * 0.001f;
511-
executor->loadWeight("matmul", 0, MATMUL_N * MATMUL_D * sizeof(float), (NnByte *)weight);
511+
executor->loadWeight("matmul", 0, MATMUL_F32_N * MATMUL_F32_D * sizeof(float), (NnByte *)weight);
512512

513513
// act
514514
executor->forward();
515515

516516
// assert
517517
for (NnUint b = 0; b < N_BATCHES; b++) {
518-
for (NnUint d = 0; d < MATMUL_D; d++) {
518+
for (NnUint d = 0; d < MATMUL_F32_D; d++) {
519519
float sum = 0.0f;
520-
for (NnUint n = 0; n < MATMUL_N; n++)
521-
sum += xPipe[b * MATMUL_N + n] * weight[d * MATMUL_N + n];
520+
for (NnUint n = 0; n < MATMUL_F32_N; n++)
521+
sum += xPipe[b * MATMUL_F32_N + n] * weight[d * MATMUL_F32_N + n];
522522

523-
const NnUint p = b * MATMUL_D + d;
523+
const NnUint p = b * MATMUL_F32_D + d;
524524
assertFloat(p, yPipe[p], sum, 0.0002f);
525525
}
526526
}
527-
printOk("matmul_F32_F32_F32");
527+
printOk("testMatmul_F32_F32_F32");
528528
});
529529
}
530530

531-
void multiheadAtt_F32_F32() {
531+
void testMatmul_Q80_Q40_F32() {
532+
#define MATMUL_Q80_Q40_N 64
533+
#define MATMUL_Q80_Q40_D 96
534+
execute(
535+
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
536+
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_Q80, N_BATCHES, MATMUL_Q80_Q40_N));
537+
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, MATMUL_Q80_Q40_D));
538+
segmentBuilder->addOp(
539+
OP_MATMUL, "matmul", 0,
540+
pointerBatchConfig(SRC_PIPE, xPipeIndex),
541+
pointerBatchConfig(SRC_PIPE, yPipeIndex),
542+
size2D(F_Q40, MATMUL_Q80_Q40_N, MATMUL_Q80_Q40_D),
543+
NnMatmulOpConfig{});
544+
},
545+
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
546+
// arrange
547+
execution->setBatchSize(N_BATCHES);
548+
NnBlockQ80 *xPipe = (NnBlockQ80 *)execution->pipes[0];
549+
float *yPipe = (float *)execution->pipes[1];
550+
551+
constexpr NnUint xSize = N_BATCHES * MATMUL_Q80_Q40_N;
552+
constexpr NnUint weightSize = MATMUL_Q80_Q40_N * MATMUL_Q80_Q40_D;
553+
constexpr NnUint weightBlocks = weightSize / Q40_BLOCK_SIZE;
554+
555+
float x[xSize];
556+
float weight[weightSize];
557+
NnBlockQ40 weightQ40[weightBlocks];
558+
559+
for (NnUint i = 0; i < xSize; i++)
560+
x[i] = i * 0.01f;
561+
for (NnUint i = 0; i < weightSize; i++)
562+
weight[i] = i * 0.001f;
563+
564+
quantizeF32toQ80(x, xPipe, xSize, 1, 0);
565+
quantizeF32toQ40(weight, weightQ40, weightSize, 1, 0);
566+
567+
executor->loadWeight("matmul", 0, weightBlocks * sizeof(NnBlockQ40), (NnByte *)weightQ40);
568+
569+
// act
570+
executor->forward();
571+
572+
// assert
573+
for (NnUint b = 0; b < N_BATCHES; b++) {
574+
for (NnUint d = 0; d < MATMUL_Q80_Q40_D; d++) {
575+
float sum = 0.0f;
576+
for (NnUint n = 0; n < MATMUL_Q80_Q40_N; n++)
577+
sum += x[b * MATMUL_Q80_Q40_N + n] * weight[d * MATMUL_Q80_Q40_N + n];
578+
const NnUint p = b * MATMUL_Q80_Q40_D + d;
579+
const float change = (yPipe[p] - sum) / sum;
580+
assertFloat(p, change, 0.0, 0.04f);
581+
}
582+
}
583+
printOk("testMatmul_Q80_Q40_F32");
584+
});
585+
}
586+
587+
void testMultiheadAtt_F32_F32() {
532588
#define MULTIHEAD_ATT_DIM 128
533589
execute(
534590
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
@@ -560,7 +616,7 @@ void multiheadAtt_F32_F32() {
560616
// TODO: for now this is a smoke test
561617
execution->setBatchSize(N_BATCHES);
562618
executor->forward();
563-
printOk("multiheadAtt_F32_F32");
619+
printOk("testMultiheadAtt_F32_F32");
564620
});
565621
}
566622

@@ -577,7 +633,8 @@ int main() {
577633
testCast_F32_F32();
578634
testCast_F32_Q80();
579635
testRope_F32_F32();
580-
matmul_F32_F32_F32();
581-
multiheadAtt_F32_F32();
636+
testMatmul_F32_F32_F32();
637+
testMatmul_Q80_Q40_F32();
638+
testMultiheadAtt_F32_F32();
582639
return 0;
583640
}

src/nn/nn-vulkan.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ static const char *getShaderFileName(const NnOpCode opCode, const NnOpQuantType
388388
}
389389
if (opCode == OP_MATMUL) {
390390
if (quantType == F32_F32_F32) return "matmul-forward-f32-f32-f32.spv";
391+
if (quantType == Q80_Q40_F32) return "matmul-forward-q80-q40-f32.spv";
391392
}
392393
if (opCode == OP_MULTIHEAD_ATT) {
393394
if (quantType == F32_F32_F32) return "multi-head-att-forward-f32-f32.spv";
src/nn/vulkan/matmul-forward-q80-q40-f32.comp

+86

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#version 450
2+
3+
#extension GL_EXT_control_flow_attributes : enable
4+
#extension GL_EXT_shader_16bit_storage : enable
5+
#extension GL_EXT_shader_explicit_arithmetic_types : enable
6+
7+
#define Q80_BLOCK_SIZE 32
8+
#define Q40_BLOCK_SIZE 32
9+
#define N_THREADS 128
10+
11+
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
12+
13+
struct BatchInfo {
14+
uint inputOffset;
15+
uint inputSizeX;
16+
uint outputOffset;
17+
uint outputSizeX;
18+
};
19+
20+
struct BlockQ80 {
21+
float16_t d;
22+
int8_t qs[Q80_BLOCK_SIZE];
23+
};
24+
25+
struct BlockQ40 {
26+
float16_t d;
27+
uint8_t qs[Q40_BLOCK_SIZE / 2];
28+
};
29+
30+
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
31+
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
32+
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
33+
layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
34+
35+
shared uint sharedStart;
36+
shared uint sharedEnd;
37+
shared BatchInfo sharedInfo;
38+
39+
void main() {
40+
const uint threadIndex = gl_LocalInvocationID.x;
41+
42+
if (threadIndex == 0) {
43+
const uint nWorkGroups = gl_NumWorkGroups.z;
44+
const uint batchIndex = gl_WorkGroupID.y;
45+
const uint workGroupIndex = gl_WorkGroupID.z;
46+
47+
const BatchInfo info = infos[batchIndex];
48+
49+
const uint ySlice = info.outputSizeX / nWorkGroups;
50+
const uint yRest = info.outputSizeX % nWorkGroups;
51+
sharedStart = workGroupIndex * ySlice + (workGroupIndex < yRest ? workGroupIndex : yRest);
52+
sharedEnd = sharedStart + ySlice + (workGroupIndex < yRest ? 1 : 0);
53+
sharedInfo = info;
54+
}
55+
56+
barrier();
57+
memoryBarrierShared();
58+
59+
const uint end = sharedEnd;
60+
const uint inputOffset = sharedInfo.inputOffset;
61+
const uint inputSizeX = sharedInfo.inputSizeX;
62+
const uint outputOffset = sharedInfo.outputOffset;
63+
64+
for (uint d = sharedStart + threadIndex; d < end; d += N_THREADS) {
65+
float16_t sum = float16_t(0.0f);
66+
67+
for (uint i = 0; i < inputSizeX; i++) {
68+
const BlockQ80 xi = x[inputOffset + i];
69+
const BlockQ40 wi = weight[d * inputSizeX + i];
70+
71+
float16_t s = float16_t(0.0f);
72+
[[unroll]] for (uint j = 0; j < Q40_BLOCK_SIZE / 2; j++) {
73+
const float16_t x0 = float16_t(xi.qs[j]);
74+
const float16_t x1 = float16_t(xi.qs[j + Q80_BLOCK_SIZE / 2]);
75+
76+
const uint8_t wq = wi.qs[j];
77+
const float16_t w0 = float16_t(wq & 0xF) - float16_t(8.0f);
78+
const float16_t w1 = float16_t(wq >> 4) - float16_t(8.0f);
79+
s += x0 * w0 + x1 * w1;
80+
}
81+
sum += s * xi.d * wi.d;
82+
}
83+
84+
y[outputOffset + d] = float(sum);
85+
}
86+
}

src/nn/vulkan/merge-add-forward-f32-f32.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct BatchInfo {
1212
};
1313

1414
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
15-
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
15+
layout(binding = 1) buffer outputBuffer { float y[]; };
1616
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
1717

1818
shared uint sharedDim;

0 commit comments

Comments (0)