From 10f26de13dcc073e2c6b3d654a22834c2d156860 Mon Sep 17 00:00:00 2001 From: b4rtaz Date: Mon, 29 Apr 2024 22:40:57 +0200 Subject: [PATCH 1/2] Revert "fix task types." This reverts commit 7f63f9ed82ffa64fa7b0cd6144edec9c2b524761. --- src/grok1-tasks.cpp | 6 +++--- src/llama2-tasks.cpp | 6 +++--- src/mixtral-tasks.cpp | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/grok1-tasks.cpp b/src/grok1-tasks.cpp index 95a11b46..ca9cab70 100644 --- a/src/grok1-tasks.cpp +++ b/src/grok1-tasks.cpp @@ -310,13 +310,13 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) { a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); a.I(llamaAttQ, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); a.I(llamaAttK, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); a.I(llamaAttV, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE); diff --git a/src/llama2-tasks.cpp b/src/llama2-tasks.cpp index 69411abb..1588f4e5 100644 --- a/src/llama2-tasks.cpp +++ b/src/llama2-tasks.cpp @@ -318,13 +318,13 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) { a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); a.I(llamaAttQ, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); a.I(llamaAttK, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); a.I(llamaAttV, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAttRope, TASK_TYPE_INFERENCE); diff --git a/src/mixtral-tasks.cpp b/src/mixtral-tasks.cpp index 0da1cf2d..6847f1d0 100644 --- a/src/mixtral-tasks.cpp +++ b/src/mixtral-tasks.cpp @@ -15,13 +15,13 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) { a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); a.I(llamaAttQ, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); a.I(llamaAttK, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); a.I(llamaAttV, TASK_TYPE_INFERENCE); a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_TRANSFER); + a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE); From 9732f42d6177e39d63b7d636693a5c3033e9585a Mon Sep 17 00:00:00 2001 From: b4rtaz Date: Mon, 29 Apr 2024 22:42:38 +0200 Subject: [PATCH 2/2] revert qkv. --- src/grok1-tasks.cpp | 24 ++++----------- src/llama2-tasks.cpp | 72 +++++++++++-------------------------------- src/llama2-tasks.hpp | 12 ++------ src/mixtral-tasks.cpp | 24 ++++----------- 4 files changed, 33 insertions(+), 99 deletions(-) diff --git a/src/grok1-tasks.cpp b/src/grok1-tasks.cpp index ca9cab70..6cd70874 100644 --- a/src/grok1-tasks.cpp +++ b/src/grok1-tasks.cpp @@ -308,15 +308,9 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) { a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.I(llamaAttQ, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.I(llamaAttK, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.I(llamaAttV, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.I(llamaQkv, TASK_TYPE_INFERENCE); + a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.I(llamaSyncQkv, TASK_TYPE_TRANSFER); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE); @@ -364,15 +358,9 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) { for (int i = 0; i < spec->nLayers; i++) { a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.W(llamaAttQ, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.W(llamaAttK, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.W(llamaAttV, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.W(llamaQkv, TASK_TYPE_INFERENCE); + a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.W(llamaSyncQkv, TASK_TYPE_TRANSFER); a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER); a.W(llamaAtt, TASK_TYPE_INFERENCE); a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE); diff --git a/src/llama2-tasks.cpp b/src/llama2-tasks.cpp index 1588f4e5..cb77565a 100644 --- a/src/llama2-tasks.cpp +++ b/src/llama2-tasks.cpp @@ -30,56 +30,32 @@ void llamaSyncRmsAtt(TASK_ARGS) { syncUnitBuffer(nThreads, threadIndex, ctx, TB_UNIT_XB_QUANTIZED); } -void llamaAttQ(TASK_ARGS) { +void llamaQkv(TASK_ARGS) { TASK_VARIABLES; - float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); - float *q0 = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex); - matmul(spec->weightsFloatType, spec->bufferFloatType, q0, xbq, block->q0, block->q0Slice->n, block->q0Slice->d0, nThreads, threadIndex); -} - -void llamaQuantizeAttQ(TASK_ARGS) { - TASK_VARIABLES; - quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED); -} - -void llamaSyncAttQ(TASK_ARGS) { - TASK_VARIABLES; - syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED); -} -void llamaAttK(TASK_ARGS) { - TASK_VARIABLES; float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); + float *q0 = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex); float *k0 = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex); - matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex); -} - -void llamaQuantizeAttK(TASK_ARGS) { - TASK_VARIABLES; - quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED); -} - -void llamaSyncAttK(TASK_ARGS) { - TASK_VARIABLES; - syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED); -} - -void llamaAttV(TASK_ARGS) { - TASK_VARIABLES; - float *xbq = (float*)transformer->buffer->getUnit(TB_UNIT_XB_QUANTIZED); float *v0 = (float*)transformer->buffer->getSliced(TB_SLICED_V, transformer->sliceIndex); - matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex); + matmul(spec->weightsFloatType, spec->bufferFloatType, q0, xbq, block->q0, block->q0Slice->n, block->q0Slice->d0, nThreads, threadIndex); + matmul(spec->weightsFloatType, spec->bufferFloatType, k0, xbq, block->k0, block->k0Slice->n, block->k0Slice->d0, nThreads, threadIndex); + matmul(spec->weightsFloatType, spec->bufferFloatType, v0, xbq, block->v0, block->v0Slice->n, block->v0Slice->d0, nThreads, threadIndex); } -void llamaQuantizeAttV(TASK_ARGS) { +void llamaQuantizeQkv(TASK_ARGS) { TASK_VARIABLES; + quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_Q, TB_SLICED_Q_QUANTIZED); + quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_K, TB_SLICED_K_QUANTIZED); quantizeSlicedBuffer(nThreads, threadIndex, ctx, false, TB_SLICED_V, TB_SLICED_V_QUANTIZED); } -void llamaSyncAttV(TASK_ARGS) { +void llamaSyncQkv(TASK_ARGS) { TASK_VARIABLES; + syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_Q_QUANTIZED); + syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_K_QUANTIZED); syncSliceOfSlicedBuffer(nThreads, threadIndex, ctx, TB_SLICED_V_QUANTIZED); + // if (ctx->socketPool != NULL && threadIndex == 0) { float* v = (float*)block->q0; printf("q0 (%d): %f %f %f %f %f %f\n", ctx->currentBlockIndex, v[0], v[1], v[2], v[3], v[4], v[5]); } } void llamaDequantizeQkv(TASK_ARGS) { @@ -316,15 +292,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) { a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.I(llamaAttQ, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.I(llamaAttK, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.I(llamaAttV, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.I(llamaQkv, TASK_TYPE_INFERENCE); + a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.I(llamaSyncQkv, TASK_TYPE_TRANSFER); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAttRope, TASK_TYPE_INFERENCE); @@ -359,15 +329,9 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) { for (int i = 0; i < spec->nLayers; i++) { a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.W(llamaAttQ, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.W(llamaAttK, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.W(llamaAttV, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.W(llamaQkv, TASK_TYPE_INFERENCE); + a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.W(llamaSyncQkv, TASK_TYPE_TRANSFER); a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER); a.W(llamaAtt, TASK_TYPE_INFERENCE); a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE); diff --git a/src/llama2-tasks.hpp b/src/llama2-tasks.hpp index 4d4984cb..91ddafbc 100644 --- a/src/llama2-tasks.hpp +++ b/src/llama2-tasks.hpp @@ -7,15 +7,9 @@ void llamaRmsAtt(TASK_ARGS); void llamaRmsAttNorm(TASK_ARGS); void llamaQuantizeRmsAtt(TASK_ARGS); void llamaSyncRmsAtt(TASK_ARGS); -void llamaAttQ(TASK_ARGS); -void llamaQuantizeAttQ(TASK_ARGS); -void llamaSyncAttQ(TASK_ARGS); -void llamaAttK(TASK_ARGS); -void llamaQuantizeAttK(TASK_ARGS); -void llamaSyncAttK(TASK_ARGS); -void llamaAttV(TASK_ARGS); -void llamaQuantizeAttV(TASK_ARGS); -void llamaSyncAttV(TASK_ARGS); +void llamaQkv(TASK_ARGS); +void llamaQuantizeQkv(TASK_ARGS); +void llamaSyncQkv(TASK_ARGS); void llamaDequantizeQkv(TASK_ARGS); void llamaMultiheadAtt(TASK_ARGS); void llamaMultiheadAttRope(TASK_ARGS); diff --git a/src/mixtral-tasks.cpp b/src/mixtral-tasks.cpp index 6847f1d0..4f017489 100644 --- a/src/mixtral-tasks.cpp +++ b/src/mixtral-tasks.cpp @@ -13,15 +13,9 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) { a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.I(llamaAttQ, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.I(llamaAttK, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.I(llamaAttV, TASK_TYPE_INFERENCE); - a.I(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.I(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.I(llamaQkv, TASK_TYPE_INFERENCE); + a.I(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.I(llamaSyncQkv, TASK_TYPE_TRANSFER); a.I(llamaDequantizeQkv, TASK_TYPE_INFERENCE); a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); a.I(grokMultiheadAttRope, TASK_TYPE_INFERENCE); @@ -64,15 +58,9 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) { for (int i = 0; i < spec->nLayers; i++) { a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); - a.W(llamaAttQ, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttQ, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttQ, TASK_TYPE_INFERENCE); - a.W(llamaAttK, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttK, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttK, TASK_TYPE_INFERENCE); - a.W(llamaAttV, TASK_TYPE_INFERENCE); - a.W(llamaQuantizeAttV, TASK_TYPE_INFERENCE); - a.W(llamaSyncAttV, TASK_TYPE_INFERENCE); + a.W(llamaQkv, TASK_TYPE_INFERENCE); + a.W(llamaQuantizeQkv, TASK_TYPE_INFERENCE); + a.W(llamaSyncQkv, TASK_TYPE_TRANSFER); a.W(llamaSyncMultiheadAtt, TASK_TYPE_TRANSFER); a.W(llamaAtt, TASK_TYPE_INFERENCE); a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);