@@ -748,7 +748,7 @@ static float dotProduct_F32(const float *a, const float *b, const unsigned int s
748
748
749
749
static void multiheadAtt_F32 (
750
750
float *x, const float *q, float *att, float *keyCache, float *valueCache,
751
- const unsigned pos, const NnUint nHeads, const NnUint nHeads0, const NnUint nKvHeads, const NnUint kvDim0, const NnUint headSize, const NnUint seqLen,
751
+ const NnUint pos, const NnUint nHeads, const NnUint nHeads0, const NnUint nKvHeads, const NnUint kvDim0, const NnUint headSize, const NnUint seqLen,
752
752
const NnUint nThreads, const NnUint threadIndex)
753
753
{
754
754
SPLIT_THREADS (h0Start, h0End, nHeads0, nThreads, threadIndex);
@@ -1150,7 +1150,9 @@ static void multiHeadAttForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnU
1150
1150
DEBUG_VECTOR (context, " input" , i);
1151
1151
DEBUG_VECTOR (context, " q" , q);
1152
1152
1153
- multiheadAtt_F32 (i, q, att, keyCache, valueCache, pos,
1153
+ multiheadAtt_F32 (i, q,
1154
+ &att[batchIndex * config->nHeads0 * config->seqLen ],
1155
+ keyCache, valueCache, pos,
1154
1156
config->nHeads , config->nHeads0 ,
1155
1157
config->nKvHeads , config->kvDim0 , config->headSize , config->seqLen , nThreads, threadIndex);
1156
1158
0 commit comments