Skip to content

Commit c96f663

Browse files
committed
fix: att offset.
1 parent b91efdd commit c96f663

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/nn/nn-cpu-ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ static float dotProduct_F32(const float *a, const float *b, const unsigned int s
748748

749749
static void multiheadAtt_F32(
750750
float *x, const float *q, float *att, float *keyCache, float *valueCache,
751-
const unsigned pos, const NnUint nHeads, const NnUint nHeads0, const NnUint nKvHeads, const NnUint kvDim0, const NnUint headSize, const NnUint seqLen,
751+
const NnUint pos, const NnUint nHeads, const NnUint nHeads0, const NnUint nKvHeads, const NnUint kvDim0, const NnUint headSize, const NnUint seqLen,
752752
const NnUint nThreads, const NnUint threadIndex)
753753
{
754754
SPLIT_THREADS(h0Start, h0End, nHeads0, nThreads, threadIndex);
@@ -1150,7 +1150,9 @@ static void multiHeadAttForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnU
11501150
DEBUG_VECTOR(context, "input", i);
11511151
DEBUG_VECTOR(context, "q", q);
11521152

1153-
multiheadAtt_F32(i, q, att, keyCache, valueCache, pos,
1153+
multiheadAtt_F32(i, q,
1154+
&att[batchIndex * config->nHeads0 * config->seqLen],
1155+
keyCache, valueCache, pos,
11541156
config->nHeads, config->nHeads0,
11551157
config->nKvHeads, config->kvDim0, config->headSize, config->seqLen, nThreads, threadIndex);
11561158

0 commit comments

Comments
 (0)