Skip to content

Commit d5b8354

Browse files
committed
sync qkv.
1 parent ad10e18 commit d5b8354

File tree

4 files changed

+99
-33
lines changed

4 files changed

+99
-33
lines changed

src/grok1-tasks.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,15 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {
308308
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
309309
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
310310
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
311-
a.I(llamaQkv, TASK_TYPE_INFERENCE);
312-
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
313-
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
311+
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
312+
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
313+
a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE);
314+
a.I(llamaAttK, TASK_TYPE_INFERENCE);
315+
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
316+
a.I(llamaSyncAttK, TASK_TYPE_INFERENCE);
317+
a.I(llamaAttV, TASK_TYPE_INFERENCE);
318+
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
319+
a.I(llamaSyncAttV, TASK_TYPE_INFERENCE);
314320
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
315321
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
316322
a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -358,9 +364,15 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {
358364

359365
for (int i = 0; i < spec->nLayers; i++) {
360366
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
361-
a.W(llamaQkv, TASK_TYPE_INFERENCE);
362-
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
363-
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
367+
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
368+
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
369+
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
370+
a.W(llamaAttK, TASK_TYPE_INFERENCE);
371+
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
372+
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
373+
a.W(llamaAttV, TASK_TYPE_INFERENCE);
374+
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
375+
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
364376
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
365377
a.W(llamaAtt, TASK_TYPE_INFERENCE);
366378
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

src/llama2-tasks.cpp

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,32 +30,56 @@ void llamaSyncRmsAtt(TASK_ARGS) {
3030
syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED);
3131
}
3232

33-
void llamaQkv(TASK_ARGS) {
33+
void llamaAttQ(TASK_ARGS) {
3434
TASK_VARIABLES;
35-
3635
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
3736
float *q0 = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex);
38-
float *k0 = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex);
39-
float *v0 = (float*)transformer->buffer->getSliced(TB_SLICED_V, transformer->sliceIndex);
40-
4137
matmul(spec->weightsFloatType, spec->bufferFloatType, q0, xbq, block->q0, block->q0Slice->n, block->q0Slice->d0, nThreads, threadIndex);
42-
matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex);
43-
matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex);
4438
}
4539

46-
void llamaQuantizeQkv(TASK_ARGS) {
40+
void llamaQuantizeAttQ(TASK_ARGS) {
4741
TASK_VARIABLES;
4842
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
49-
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED);
50-
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_V, TB_SLICED_V_QUANTIZED);
5143
}
5244

53-
void llamaSyncQkv(TASK_ARGS) {
45+
void llamaSyncAttQ(TASK_ARGS) {
5446
TASK_VARIABLES;
5547
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
48+
}
49+
50+
void llamaAttK(TASK_ARGS) {
51+
TASK_VARIABLES;
52+
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
53+
float *k0 = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex);
54+
matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex);
55+
}
56+
57+
void llamaQuantizeAttK(TASK_ARGS) {
58+
TASK_VARIABLES;
59+
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED);
60+
}
61+
62+
void llamaSyncAttK(TASK_ARGS) {
63+
TASK_VARIABLES;
5664
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
65+
}
66+
67+
void llamaAttV(TASK_ARGS) {
68+
TASK_VARIABLES;
69+
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
70+
float *v0 = (float*)transformer->buffer->getSliced(TB_SLICED_V, transformer->sliceIndex);
71+
matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex);
72+
73+
}
74+
75+
void llamaQuantizeAttV(TASK_ARGS) {
76+
TASK_VARIABLES;
77+
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_V, TB_SLICED_V_QUANTIZED);
78+
}
79+
80+
void llamaSyncAttV(TASK_ARGS) {
81+
TASK_VARIABLES;
5782
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_V_QUANTIZED);
58-
// if (ctx->socketPool != NULL && threadIndex == 0) { float* v = (float*)block->q0; printf("q0 (%d): %f %f %f %f %f %f\n", ctx->currentBlockIndex, v[0], v[1], v[2], v[3], v[4], v[5]); }
5983
}
6084

6185
void llamaDequantizeQkv(TASK_ARGS) {
@@ -292,9 +316,15 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
292316
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
293317
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
294318
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
295-
a.I(llamaQkv, TASK_TYPE_INFERENCE);
296-
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
297-
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
319+
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
320+
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
321+
a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE);
322+
a.I(llamaAttK, TASK_TYPE_INFERENCE);
323+
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
324+
a.I(llamaSyncAttK, TASK_TYPE_INFERENCE);
325+
a.I(llamaAttV, TASK_TYPE_INFERENCE);
326+
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
327+
a.I(llamaSyncAttV, TASK_TYPE_INFERENCE);
298328
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
299329
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
300330
a.I(llamaMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -329,9 +359,15 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
329359

330360
for (int i = 0; i < spec->nLayers; i++) {
331361
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
332-
a.W(llamaQkv, TASK_TYPE_INFERENCE);
333-
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
334-
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
362+
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
363+
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
364+
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
365+
a.W(llamaAttK, TASK_TYPE_INFERENCE);
366+
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
367+
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
368+
a.W(llamaAttV, TASK_TYPE_INFERENCE);
369+
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
370+
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
335371
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
336372
a.W(llamaAtt, TASK_TYPE_INFERENCE);
337373
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

src/llama2-tasks.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@ void llamaRmsAtt(TASK_ARGS);
77
void llamaRmsAttNorm(TASK_ARGS);
88
void llamaQuantizeRmsAtt(TASK_ARGS);
99
void llamaSyncRmsAtt(TASK_ARGS);
10-
void llamaQkv(TASK_ARGS);
11-
void llamaQuantizeQkv(TASK_ARGS);
12-
void llamaSyncQkv(TASK_ARGS);
10+
void llamaAttQ(TASK_ARGS);
11+
void llamaQuantizeAttQ(TASK_ARGS);
12+
void llamaSyncAttQ(TASK_ARGS);
13+
void llamaAttK(TASK_ARGS);
14+
void llamaQuantizeAttK(TASK_ARGS);
15+
void llamaSyncAttK(TASK_ARGS);
16+
void llamaAttV(TASK_ARGS);
17+
void llamaQuantizeAttV(TASK_ARGS);
18+
void llamaSyncAttV(TASK_ARGS);
1319
void llamaDequantizeQkv(TASK_ARGS);
1420
void llamaMultiheadAtt(TASK_ARGS);
1521
void llamaMultiheadAttRope(TASK_ARGS);

src/mixtral-tasks.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
1313
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
1414
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
1515
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
16-
a.I(llamaQkv, TASK_TYPE_INFERENCE);
17-
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
18-
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
16+
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
17+
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
18+
a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE);
19+
a.I(llamaAttK, TASK_TYPE_INFERENCE);
20+
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
21+
a.I(llamaSyncAttK, TASK_TYPE_INFERENCE);
22+
a.I(llamaAttV, TASK_TYPE_INFERENCE);
23+
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
24+
a.I(llamaSyncAttV, TASK_TYPE_INFERENCE);
1925
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
2026
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
2127
a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -58,9 +64,15 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
5864

5965
for (int i = 0; i < spec->nLayers; i++) {
6066
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
61-
a.W(llamaQkv, TASK_TYPE_INFERENCE);
62-
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
63-
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
67+
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
68+
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
69+
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
70+
a.W(llamaAttK, TASK_TYPE_INFERENCE);
71+
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
72+
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
73+
a.W(llamaAttV, TASK_TYPE_INFERENCE);
74+
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
75+
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
6476
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
6577
a.W(llamaAtt, TASK_TYPE_INFERENCE);
6678
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

0 commit comments

Comments
 (0)