Skip to content

Commit b2f3450

Browse files
authored
qkv. (#36)
1 parent 7f63f9e commit b2f3450

File tree

4 files changed

+33
-99
lines changed

4 files changed

+33
-99
lines changed

src/grok1-tasks.cpp

+6-18
Original file line numberDiff line numberDiff line change
@@ -308,15 +308,9 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {
308308
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
309309
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
310310
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
311-
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
312-
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
313-
a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER);
314-
a.I(llamaAttK, TASK_TYPE_INFERENCE);
315-
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
316-
a.I(llamaSyncAttK, TASK_TYPE_TRANSFER);
317-
a.I(llamaAttV, TASK_TYPE_INFERENCE);
318-
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
319-
a.I(llamaSyncAttV, TASK_TYPE_TRANSFER);
311+
a.I(llamaQkv, TASK_TYPE_INFERENCE);
312+
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
313+
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
320314
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
321315
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
322316
a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -364,15 +358,9 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {
364358

365359
for (int i = 0; i < spec->nLayers; i++) {
366360
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
367-
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
368-
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
369-
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
370-
a.W(llamaAttK, TASK_TYPE_INFERENCE);
371-
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
372-
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
373-
a.W(llamaAttV, TASK_TYPE_INFERENCE);
374-
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
375-
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
361+
a.W(llamaQkv, TASK_TYPE_INFERENCE);
362+
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
363+
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
376364
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
377365
a.W(llamaAtt, TASK_TYPE_INFERENCE);
378366
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

src/llama2-tasks.cpp

+18-54
Original file line numberDiff line numberDiff line change
@@ -30,56 +30,32 @@ void llamaSyncRmsAtt(TASK_ARGS) {
3030
syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED);
3131
}
3232

33-
void llamaAttQ(TASK_ARGS) {
33+
void llamaQkv(TASK_ARGS) {
3434
TASK_VARIABLES;
35-
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
36-
float *q0 = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex);
37-
matmul(spec->weightsFloatType, spec->bufferFloatType, q0, xbq, block->q0, block->q0Slice->n, block->q0Slice->d0, nThreads, threadIndex);
38-
}
39-
40-
void llamaQuantizeAttQ(TASK_ARGS) {
41-
TASK_VARIABLES;
42-
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
43-
}
44-
45-
void llamaSyncAttQ(TASK_ARGS) {
46-
TASK_VARIABLES;
47-
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
48-
}
4935

50-
void llamaAttK(TASK_ARGS) {
51-
TASK_VARIABLES;
5236
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
37+
float *q0 = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex);
5338
float *k0 = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex);
54-
matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex);
55-
}
56-
57-
void llamaQuantizeAttK(TASK_ARGS) {
58-
TASK_VARIABLES;
59-
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED);
60-
}
61-
62-
void llamaSyncAttK(TASK_ARGS) {
63-
TASK_VARIABLES;
64-
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
65-
}
66-
67-
void llamaAttV(TASK_ARGS) {
68-
TASK_VARIABLES;
69-
float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED);
7039
float *v0 = (float*)transformer->buffer->getSliced(TB_SLICED_V, transformer->sliceIndex);
71-
matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex);
7240

41+
matmul(spec->weightsFloatType, spec->bufferFloatType, q0, xbq, block->q0, block->q0Slice->n, block->q0Slice->d0, nThreads, threadIndex);
42+
matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex);
43+
matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex);
7344
}
7445

75-
void llamaQuantizeAttV(TASK_ARGS) {
46+
void llamaQuantizeQkv(TASK_ARGS) {
7647
TASK_VARIABLES;
48+
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED);
49+
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED);
7750
quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_V, TB_SLICED_V_QUANTIZED);
7851
}
7952

80-
void llamaSyncAttV(TASK_ARGS) {
53+
void llamaSyncQkv(TASK_ARGS) {
8154
TASK_VARIABLES;
55+
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED);
56+
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED);
8257
syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_V_QUANTIZED);
58+
// if (ctx->socketPool != NULL && threadIndex == 0) { float* v = (float*)block->q0; printf("q0 (%d): %f %f %f %f %f %f\n", ctx->currentBlockIndex, v[0], v[1], v[2], v[3], v[4], v[5]); }
8359
}
8460

8561
void llamaDequantizeQkv(TASK_ARGS) {
@@ -316,15 +292,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
316292
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
317293
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
318294
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
319-
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
320-
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
321-
a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER);
322-
a.I(llamaAttK, TASK_TYPE_INFERENCE);
323-
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
324-
a.I(llamaSyncAttK, TASK_TYPE_TRANSFER);
325-
a.I(llamaAttV, TASK_TYPE_INFERENCE);
326-
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
327-
a.I(llamaSyncAttV, TASK_TYPE_TRANSFER);
295+
a.I(llamaQkv, TASK_TYPE_INFERENCE);
296+
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
297+
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
328298
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
329299
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
330300
a.I(llamaMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -359,15 +329,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
359329

360330
for (int i = 0; i < spec->nLayers; i++) {
361331
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
362-
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
363-
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
364-
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
365-
a.W(llamaAttK, TASK_TYPE_INFERENCE);
366-
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
367-
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
368-
a.W(llamaAttV, TASK_TYPE_INFERENCE);
369-
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
370-
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
332+
a.W(llamaQkv, TASK_TYPE_INFERENCE);
333+
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
334+
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
371335
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
372336
a.W(llamaAtt, TASK_TYPE_INFERENCE);
373337
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

src/llama2-tasks.hpp

+3-9
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,9 @@ void llamaRmsAtt(TASK_ARGS);
77
void llamaRmsAttNorm(TASK_ARGS);
88
void llamaQuantizeRmsAtt(TASK_ARGS);
99
void llamaSyncRmsAtt(TASK_ARGS);
10-
void llamaAttQ(TASK_ARGS);
11-
void llamaQuantizeAttQ(TASK_ARGS);
12-
void llamaSyncAttQ(TASK_ARGS);
13-
void llamaAttK(TASK_ARGS);
14-
void llamaQuantizeAttK(TASK_ARGS);
15-
void llamaSyncAttK(TASK_ARGS);
16-
void llamaAttV(TASK_ARGS);
17-
void llamaQuantizeAttV(TASK_ARGS);
18-
void llamaSyncAttV(TASK_ARGS);
10+
void llamaQkv(TASK_ARGS);
11+
void llamaQuantizeQkv(TASK_ARGS);
12+
void llamaSyncQkv(TASK_ARGS);
1913
void llamaDequantizeQkv(TASK_ARGS);
2014
void llamaMultiheadAtt(TASK_ARGS);
2115
void llamaMultiheadAttRope(TASK_ARGS);

src/mixtral-tasks.cpp

+6-18
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,9 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
1313
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
1414
a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE);
1515
a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
16-
a.I(llamaAttQ, TASK_TYPE_INFERENCE);
17-
a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
18-
a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER);
19-
a.I(llamaAttK, TASK_TYPE_INFERENCE);
20-
a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
21-
a.I(llamaSyncAttK, TASK_TYPE_TRANSFER);
22-
a.I(llamaAttV, TASK_TYPE_INFERENCE);
23-
a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
24-
a.I(llamaSyncAttV, TASK_TYPE_TRANSFER);
16+
a.I(llamaQkv, TASK_TYPE_INFERENCE);
17+
a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
18+
a.I(llamaSyncQkv, TASK_TYPE_TRANSFER);
2519
a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE);
2620
a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE);
2721
a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE);
@@ -64,15 +58,9 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
6458

6559
for (int i = 0; i < spec->nLayers; i++) {
6660
a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
67-
a.W(llamaAttQ, TASK_TYPE_INFERENCE);
68-
a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE);
69-
a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE);
70-
a.W(llamaAttK, TASK_TYPE_INFERENCE);
71-
a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE);
72-
a.W(llamaSyncAttK, TASK_TYPE_INFERENCE);
73-
a.W(llamaAttV, TASK_TYPE_INFERENCE);
74-
a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE);
75-
a.W(llamaSyncAttV, TASK_TYPE_INFERENCE);
61+
a.W(llamaQkv, TASK_TYPE_INFERENCE);
62+
a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE);
63+
a.W(llamaSyncQkv, TASK_TYPE_TRANSFER);
7664
a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER);
7765
a.W(llamaAtt, TASK_TYPE_INFERENCE);
7866
a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);

0 commit comments

Comments
 (0)