@@ -30,56 +30,32 @@ void llamaSyncRmsAtt(TASK_ARGS) {
30
30
syncUnitBuffer (nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED);
31
31
}
32
32
33
- void llamaAttQ (TASK_ARGS) {
33
+ void llamaQkv (TASK_ARGS) {
34
34
TASK_VARIABLES;
35
- float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
36
- float *q0 = (float *)transformer->buffer ->getSliced (TB_SLICED_Q, transformer->sliceIndex );
37
- matmul (spec->weightsFloatType , spec->bufferFloatType , q0, xbq, block->q0 , block->q0Slice ->n , block->q0Slice ->d0 , nThreads, threadIndex);
38
- }
39
-
40
- void llamaQuantizeAttQ (TASK_ARGS) {
41
- TASK_VARIABLES;
42
- quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
43
- }
44
-
45
- void llamaSyncAttQ (TASK_ARGS) {
46
- TASK_VARIABLES;
47
- syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
48
- }
49
35
50
- void llamaAttK (TASK_ARGS) {
51
- TASK_VARIABLES;
52
36
float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
37
+ float *q0 = (float *)transformer->buffer ->getSliced (TB_SLICED_Q, transformer->sliceIndex );
53
38
float *k0 = (float *)transformer->buffer ->getSliced (TB_SLICED_K, transformer->sliceIndex );
54
- matmul (spec->weightsFloatType , spec->bufferFloatType , k0, xbq, block->k0 , block->k0Slice ->n , block->k0Slice ->d0 , nThreads, threadIndex);
55
- }
56
-
57
- void llamaQuantizeAttK (TASK_ARGS) {
58
- TASK_VARIABLES;
59
- quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_K, TB_SLICED_K_QUANTIZED);
60
- }
61
-
62
- void llamaSyncAttK (TASK_ARGS) {
63
- TASK_VARIABLES;
64
- syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
65
- }
66
-
67
- void llamaAttV (TASK_ARGS) {
68
- TASK_VARIABLES;
69
- float *xbq = (float *)transformer->buffer ->getUnit (TB_UNIT_XB_QUANTIZED);
70
39
float *v0 = (float *)transformer->buffer ->getSliced (TB_SLICED_V, transformer->sliceIndex );
71
- matmul (spec->weightsFloatType , spec->bufferFloatType , v0, xbq, block->v0 , block->v0Slice ->n , block->v0Slice ->d0 , nThreads, threadIndex);
72
40
41
+ matmul (spec->weightsFloatType , spec->bufferFloatType , q0, xbq, block->q0 , block->q0Slice ->n , block->q0Slice ->d0 , nThreads, threadIndex);
42
+ matmul (spec->weightsFloatType , spec->bufferFloatType , k0, xbq, block->k0 , block->k0Slice ->n , block->k0Slice ->d0 , nThreads, threadIndex);
43
+ matmul (spec->weightsFloatType , spec->bufferFloatType , v0, xbq, block->v0 , block->v0Slice ->n , block->v0Slice ->d0 , nThreads, threadIndex);
73
44
}
74
45
75
- void llamaQuantizeAttV (TASK_ARGS) {
46
+ void llamaQuantizeQkv (TASK_ARGS) {
76
47
TASK_VARIABLES;
48
+ quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
49
+ quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_K, TB_SLICED_K_QUANTIZED);
77
50
quantizeSlicedBuffer (nThreads, threadIndex, ctx, false , TB_SLICED_V, TB_SLICED_V_QUANTIZED);
78
51
}
79
52
80
- void llamaSyncAttV (TASK_ARGS) {
53
+ void llamaSyncQkv (TASK_ARGS) {
81
54
TASK_VARIABLES;
55
+ syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
56
+ syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
82
57
syncSliceOfSlicedBuffer (nThreads, threadIndex, ctx, TB_SLICED_V_QUANTIZED);
58
+ // if (ctx->socketPool != NULL && threadIndex == 0) { float* v = (float*)block->q0; printf("q0 (%d): %f %f %f %f %f %f\n", ctx->currentBlockIndex, v[0], v[1], v[2], v[3], v[4], v[5]); }
83
59
}
84
60
85
61
void llamaDequantizeQkv (TASK_ARGS) {
@@ -316,15 +292,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
316
292
a.I (llamaRmsAttNorm, TASK_TYPE_INFERENCE);
317
293
a.I (llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
318
294
a.I (llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
319
- a.I (llamaAttQ, TASK_TYPE_INFERENCE);
320
- a.I (llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
321
- a.I (llamaSyncAttQ, TASK_TYPE_TRANSFER);
322
- a.I (llamaAttK, TASK_TYPE_INFERENCE);
323
- a.I (llamaQuantizeAttK, TASK_TYPE_INFERENCE);
324
- a.I (llamaSyncAttK, TASK_TYPE_TRANSFER);
325
- a.I (llamaAttV, TASK_TYPE_INFERENCE);
326
- a.I (llamaQuantizeAttV, TASK_TYPE_INFERENCE);
327
- a.I (llamaSyncAttV, TASK_TYPE_TRANSFER);
295
+ a.I (llamaQkv, TASK_TYPE_INFERENCE);
296
+ a.I (llamaQuantizeQkv, TASK_TYPE_INFERENCE);
297
+ a.I (llamaSyncQkv, TASK_TYPE_TRANSFER);
328
298
a.I (llamaDequantizeQkv, TASK_TYPE_INFERENCE);
329
299
a.I (llamaMultiheadAtt, TASK_TYPE_INFERENCE);
330
300
a.I (llamaMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -359,15 +329,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
359
329
360
330
for (int i = 0 ; i < spec->nLayers ; i++) {
361
331
a.W (llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
362
- a.W (llamaAttQ, TASK_TYPE_INFERENCE);
363
- a.W (llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
364
- a.W (llamaSyncAttQ, TASK_TYPE_INFERENCE);
365
- a.W (llamaAttK, TASK_TYPE_INFERENCE);
366
- a.W (llamaQuantizeAttK, TASK_TYPE_INFERENCE);
367
- a.W (llamaSyncAttK, TASK_TYPE_INFERENCE);
368
- a.W (llamaAttV, TASK_TYPE_INFERENCE);
369
- a.W (llamaQuantizeAttV, TASK_TYPE_INFERENCE);
370
- a.W (llamaSyncAttV, TASK_TYPE_INFERENCE);
332
+ a.W (llamaQkv, TASK_TYPE_INFERENCE);
333
+ a.W (llamaQuantizeQkv, TASK_TYPE_INFERENCE);
334
+ a.W (llamaSyncQkv, TASK_TYPE_TRANSFER);
371
335
a.W (llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
372
336
a.W (llamaAtt, TASK_TYPE_INFERENCE);
373
337
a.W (llamaQuantizeAtt, TASK_TYPE_INFERENCE);
0 commit comments