@@ -51,6 +51,40 @@ long MatmulSlice::mergeOutputs(uint8_t sliceIndex, float* output, float* output0
51
51
return offset; // offset in floats
52
52
}
53
53
54
+ void initRope (float * cache, TransformerSpec* spec) {
55
+ for (int pos = 0 ; pos < spec->seqLen ; pos++) {
56
+ for (int i = 0 ; i < spec->dim ; i += 2 ) {
57
+ int head_dim = i % spec->headSize ;
58
+ float freq = 1 .0f / powf (spec->ropeTheta , head_dim / (float )spec->headSize );
59
+ float val = pos * freq;
60
+ float fcr = cosf (val);
61
+ float fci = sinf (val);
62
+ cache[pos * spec->seqLen + i] = fcr;
63
+ cache[pos * spec->seqLen + i + 1 ] = fci;
64
+ }
65
+ }
66
+ }
67
+
68
+ void rope (float * cache, float * q, float * k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex) {
69
+ int slice = spec->dim / (nThreads * 2 );
70
+ int iStart = (threadIndex * slice) * 2 ;
71
+ int iEnd = ((nThreads - 1 == threadIndex) ? spec->dim : (iStart + slice)) * 2 ;
72
+
73
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
74
+ for (int i = iStart; i < iEnd; i += 2 ) {
75
+ float fcr = cache[pos * spec->seqLen + i];
76
+ float fci = cache[pos * spec->seqLen + i + 1 ];
77
+ int rotn = i < spec->kvDim ? 2 : 1 ; // how many vectors? 2 = q & k, 1 = q only
78
+ for (int _v = 0 ; _v < rotn; _v++) {
79
+ float * vec = _v == 0 ? q : k; // the vector to rotate (query or key)
80
+ float v0 = vec[i];
81
+ float v1 = vec[i+1 ];
82
+ vec[i] = v0 * fcr - v1 * fci;
83
+ vec[i+1 ] = v0 * fci + v1 * fcr;
84
+ }
85
+ }
86
+ }
87
+
54
88
TransformerSpec Transformer::loadSpecFromFile (const char * path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
55
89
TransformerSpec spec;
56
90
memset (&spec, 0 , sizeof (TransformerSpec));
@@ -252,6 +286,12 @@ Transformer::Transformer(TransformerSpec* spec, uint8_t sliceIndex) {
252
286
#endif
253
287
x = (float *)NEW_BUFFER (spec->dim * sizeof (float ));
254
288
logits = (float *)NEW_BUFFER (spec->vocabSize * sizeof (float ));
289
+
290
+ // TODO: cache should be for all architectures
291
+ if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
292
+ ropeCache = (float *)NEW_BUFFER (spec->vocabSize * spec->dim );
293
+ initRope (ropeCache, spec);
294
+ }
255
295
}
256
296
}
257
297
@@ -270,6 +310,10 @@ Transformer::~Transformer() {
270
310
#endif
271
311
FREE_BUFFER (x);
272
312
FREE_BUFFER (logits);
313
+
314
+ if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
315
+ FREE_BUFFER (ropeCache);
316
+ }
273
317
}
274
318
}
275
319
0 commit comments