@@ -30,32 +30,56 @@ void llamaSyncRmsAtt(TASK_ARGS) {
30
30
syncUnitBuffer (nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED);
31
31
}
32
32
// Diff record: llamaQkv is renamed to llamaAttQ and reduced to the Q projection.
// The k0/v0 lookups and their matmul calls are removed from this function here.
// Resulting body: read the TB_UNIT_XB_QUANTIZED unit buffer (xbq) and the
// TB_SLICED_Q slice at transformer->sliceIndex (q0), then call matmul with
// q0, xbq, block->q0 and the block->q0Slice dimensions (n, d0).
// NOTE(review): q0 is presumably the matmul output (it is the buffer quantized
// by the following task) -- confirm against matmul's signature.
// NOTE(review): transformer, spec, block, nThreads and threadIndex are not
// declared in this view; presumably they come from TASK_ARGS / TASK_VARIABLES.
33
- void llamaQkv (TASK_ARGS) {
33
+ void llamaAttQ (TASK_ARGS) {
34
34
TASK_VARIABLES;
35
-
36
35
float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
37
36
float *q0 = (float *)transformer->buffer ->getSliced (TB_SLICED_Q, transformer->sliceIndex );
38
- float *k0 = (float *)transformer->buffer ->getSliced (TB_SLICED_K, transformer->sliceIndex );
39
- float *v0 = (float *)transformer->buffer ->getSliced (TB_SLICED_V, transformer->sliceIndex );
40
-
41
37
matmul (spec->weightsFloatType , spec->bufferFloatType , q0, xbq, block->q0 , block->q0Slice ->n , block->q0Slice ->d0 , nThreads, threadIndex);
42
- matmul (spec->weightsFloatType , spec->bufferFloatType , k0, xbq, block->k0 , block->k0Slice ->n , block->k0Slice ->d0 , nThreads, threadIndex);
43
- matmul (spec->weightsFloatType , spec->bufferFloatType , v0, xbq, block->v0 , block->v0Slice ->n , block->v0Slice ->d0 , nThreads, threadIndex);
44
38
}
45
39
// Diff record: llamaQuantizeQkv is renamed to llamaQuantizeAttQ.
// Only the TB_SLICED_Q -> TB_SLICED_Q_QUANTIZED quantizeSlicedBuffer call is
// kept; the K and V calls are removed from this function.
// NOTE(review): the meaning of the `false` argument is not visible in this
// view -- check quantizeSlicedBuffer's declaration.
46
- void llamaQuantizeQkv (TASK_ARGS) {
40
+ void llamaQuantizeAttQ (TASK_ARGS) {
47
41
TASK_VARIABLES;
48
42
quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
49
- quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_K, TB_SLICED_K_QUANTIZED);
50
- quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_V, TB_SLICED_V_QUANTIZED);
51
43
}
52
44
// Diff record: llamaSyncQkv is renamed to llamaSyncAttQ and now closes right
// after the TB_SLICED_Q_QUANTIZED call, so it syncs only the Q slice.
// NOTE(review): in the buildLlama2Arch hunks of this same diff the removed
// llamaSyncQkv entries are registered with TASK_TYPE_TRANSFER, while the new
// llamaSyncAttQ entries use TASK_TYPE_INFERENCE -- confirm this type change
// is intended and not a copy-paste slip.
53
- void llamaSyncQkv (TASK_ARGS) {
45
+ void llamaSyncAttQ (TASK_ARGS) {
54
46
TASK_VARIABLES;
55
47
syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
48
+ }
49
+
// Diff record: llamaAttK is added by this change.
// It reads the TB_UNIT_XB_QUANTIZED unit buffer (xbq) and the TB_SLICED_K
// slice at transformer->sliceIndex (k0), then calls matmul with k0, xbq,
// block->k0 and the block->k0Slice dimensions (n, d0) -- the same K lookup
// and matmul call that this diff removes from llamaQkv.
// NOTE(review): k0 is presumably the matmul output -- confirm against
// matmul's signature.
50
+ void llamaAttK (TASK_ARGS) {
51
+ TASK_VARIABLES;
52
+ float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
53
+ float *k0 = (float *)transformer->buffer ->getSliced (TB_SLICED_K, transformer->sliceIndex );
54
+ matmul (spec->weightsFloatType , spec->bufferFloatType , k0, xbq, block->k0 , block->k0Slice ->n , block->k0Slice ->d0 , nThreads, threadIndex);
55
+ }
56
+
// Diff record: llamaQuantizeAttK is added by this change.
// It makes the TB_SLICED_K -> TB_SLICED_K_QUANTIZED quantizeSlicedBuffer call
// that this diff removes from llamaQuantizeQkv, with identical arguments.
57
+ void llamaQuantizeAttK (TASK_ARGS) {
58
+ TASK_VARIABLES;
59
+ quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_K, TB_SLICED_K_QUANTIZED);
60
+ }
61
+
// Diff record: llamaSyncAttK is added by this change. Its single call is an
// unchanged context line: the TB_SLICED_K_QUANTIZED sync that previously sat
// inside llamaSyncQkv is now wrapped in its own function.
// NOTE(review): the buildLlama2Arch hunks of this same diff register
// llamaSyncAttK with TASK_TYPE_INFERENCE, whereas the removed llamaSyncQkv
// entries used TASK_TYPE_TRANSFER -- confirm this is intended.
62
+ void llamaSyncAttK (TASK_ARGS) {
63
+ TASK_VARIABLES;
56
64
syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
65
+ }
66
+
// Diff record: llamaAttV is added by this change.
// It reads the TB_UNIT_XB_QUANTIZED unit buffer (xbq) and the TB_SLICED_V
// slice at transformer->sliceIndex (v0), then calls matmul with v0, xbq,
// block->v0 and the block->v0Slice dimensions (n, d0) -- the same V lookup
// and matmul call that this diff removes from llamaQkv.
// NOTE(review): v0 is presumably the matmul output -- confirm against
// matmul's signature.
// NOTE(review): the empty added line before the closing brace has no
// counterpart in llamaAttQ / llamaAttK; it looks like a formatting leftover.
67
+ void llamaAttV (TASK_ARGS) {
68
+ TASK_VARIABLES;
69
+ float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
70
+ float *v0 = (float *)transformer->buffer ->getSliced (TB_SLICED_V, transformer->sliceIndex );
71
+ matmul (spec->weightsFloatType , spec->bufferFloatType , v0, xbq, block->v0 , block->v0Slice ->n , block->v0Slice ->d0 , nThreads, threadIndex);
72
+
73
+ }
74
+
// Diff record: llamaQuantizeAttV is added by this change.
// It makes the TB_SLICED_V -> TB_SLICED_V_QUANTIZED quantizeSlicedBuffer call
// that this diff removes from llamaQuantizeQkv, with identical arguments.
75
+ void llamaQuantizeAttV (TASK_ARGS) {
76
+ TASK_VARIABLES;
77
+ quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_V, TB_SLICED_V_QUANTIZED);
78
+ }
79
+
// Diff record: llamaSyncAttV is added by this change. Its sync call and its
// closing brace are unchanged context lines that previously ended
// llamaSyncQkv; the commented-out debug printf of block->q0 values is deleted.
// NOTE(review): the buildLlama2Arch hunks of this same diff register
// llamaSyncAttV with TASK_TYPE_INFERENCE, whereas the removed llamaSyncQkv
// entries used TASK_TYPE_TRANSFER -- confirm this is intended.
80
+ void llamaSyncAttV (TASK_ARGS) {
81
+ TASK_VARIABLES;
57
82
syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_V_QUANTIZED);
58
- // if (ctx->socketPool != NULL && threadIndex == 0) { float* v = (float*)block->q0; printf("q0 (%d): %f %f %f %f %f %f\n", ctx->currentBlockIndex, v[0], v[1], v[2], v[3], v[4], v[5]); }
59
83
}
60
84
61
85
void llamaDequantizeQkv (TASK_ARGS) {
@@ -292,9 +316,15 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
292
316
a.I (llamaRmsAttNorm, TASK_TYPE_INFERENCE);
293
317
a.I (llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
294
318
a.I (llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
295
- a.I (llamaQkv, TASK_TYPE_INFERENCE);
296
- a.I (llamaQuantizeQkv, TASK_TYPE_INFERENCE);
297
- a.I (llamaSyncQkv, TASK_TYPE_TRANSFER);
319
+ a.I (llamaAttQ, TASK_TYPE_INFERENCE);
320
+ a.I (llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
321
+ a.I (llamaSyncAttQ, TASK_TYPE_INFERENCE);
322
+ a.I (llamaAttK, TASK_TYPE_INFERENCE);
323
+ a.I (llamaQuantizeAttK, TASK_TYPE_INFERENCE);
324
+ a.I (llamaSyncAttK, TASK_TYPE_INFERENCE);
325
+ a.I (llamaAttV, TASK_TYPE_INFERENCE);
326
+ a.I (llamaQuantizeAttV, TASK_TYPE_INFERENCE);
327
+ a.I (llamaSyncAttV, TASK_TYPE_INFERENCE);
298
328
a.I (llamaDequantizeQkv, TASK_TYPE_INFERENCE);
299
329
a.I (llamaMultiheadAtt, TASK_TYPE_INFERENCE);
300
330
a.I (llamaMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -329,9 +359,15 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
329
359
330
360
for (int i = 0 ; i < spec->nLayers ; i++) {
331
361
a.W (llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
332
- a.W (llamaQkv, TASK_TYPE_INFERENCE);
333
- a.W (llamaQuantizeQkv, TASK_TYPE_INFERENCE);
334
- a.W (llamaSyncQkv, TASK_TYPE_TRANSFER);
362
+ a.W (llamaAttQ, TASK_TYPE_INFERENCE);
363
+ a.W (llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
364
+ a.W (llamaSyncAttQ, TASK_TYPE_INFERENCE);
365
+ a.W (llamaAttK, TASK_TYPE_INFERENCE);
366
+ a.W (llamaQuantizeAttK, TASK_TYPE_INFERENCE);
367
+ a.W (llamaSyncAttK, TASK_TYPE_INFERENCE);
368
+ a.W (llamaAttV, TASK_TYPE_INFERENCE);
369
+ a.W (llamaQuantizeAttV, TASK_TYPE_INFERENCE);
370
+ a.W (llamaSyncAttV, TASK_TYPE_INFERENCE);
335
371
a.W (llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
336
372
a.W (llamaAtt, TASK_TYPE_INFERENCE);
337
373
a.W (llamaQuantizeAtt, TASK_TYPE_INFERENCE);
0 commit comments