Skip to content

Commit 69b4d9e

Browse files
committed
multithread rope.
1 parent 9dbfa9f commit 69b4d9e

File tree

3 files changed

+51
-21
lines changed

3 files changed

+51
-21
lines changed

src/llama2-tasks.cpp

+3-21
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,9 @@ void llamaMultiheadAtt(TASK_ARGS) {
7878

7979
void llamaMultiheadAttRope(TASK_ARGS) {
8080
TASK_VARIABLES;
81-
if (threadIndex == 0) {
82-
float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q);
83-
float* k = block->keyCache + transformer->pos * spec->kvDim;
84-
85-
// RoPE relative positional encoding: complex-valued rotate q and k in each head
86-
for (int i = 0; i < spec->dim; i+=2) {
87-
int head_dim = i % spec->headSize;
88-
float freq = 1.0f / powf(spec->ropeTheta, head_dim / (float)spec->headSize);
89-
float val = transformer->pos * freq;
90-
float fcr = cosf(val);
91-
float fci = sinf(val);
92-
int rotn = i < spec->kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
93-
for (int _v = 0; _v < rotn; _v++) {
94-
float* vec = _v == 0 ? q : k; // the vector to rotate (query or key)
95-
float v0 = vec[i];
96-
float v1 = vec[i+1];
97-
vec[i] = v0 * fcr - v1 * fci;
98-
vec[i+1] = v0 * fci + v1 * fcr;
99-
}
100-
}
101-
}
81+
float* q = (float*)transformer->buffer->getUnit(TB_SLICED_Q);
82+
float* k = block->keyCache + transformer->pos * spec->kvDim;
83+
rope(transformer->ropeCache, q, k, spec, transformer->pos, nThreads, threadIndex);
10284
}
10385

10486
void llamaMultiheadAttJoin(TASK_ARGS) {

src/transformer.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,40 @@ long MatmulSlice::mergeOutputs(uint8_t sliceIndex, float* output, float* output0
5151
return offset; // offset in floats
5252
}
5353

54+
void initRope(float* cache, TransformerSpec* spec) {
55+
for (int pos = 0; pos < spec->seqLen; pos++) {
56+
for (int i = 0; i < spec->dim; i += 2) {
57+
int head_dim = i % spec->headSize;
58+
float freq = 1.0f / powf(spec->ropeTheta, head_dim / (float)spec->headSize);
59+
float val = pos * freq;
60+
float fcr = cosf(val);
61+
float fci = sinf(val);
62+
cache[pos * spec->seqLen + i] = fcr;
63+
cache[pos * spec->seqLen + i + 1] = fci;
64+
}
65+
}
66+
}
67+
68+
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex) {
69+
int slice = spec->dim / (nThreads * 2);
70+
int iStart = (threadIndex * slice) * 2;
71+
int iEnd = ((nThreads - 1 == threadIndex) ? spec->dim : (iStart + slice)) * 2;
72+
73+
// RoPE relative positional encoding: complex-valued rotate q and k in each head
74+
for (int i = iStart; i < iEnd; i += 2) {
75+
float fcr = cache[pos * spec->seqLen + i];
76+
float fci = cache[pos * spec->seqLen + i + 1];
77+
int rotn = i < spec->kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
78+
for (int _v = 0; _v < rotn; _v++) {
79+
float* vec = _v == 0 ? q : k; // the vector to rotate (query or key)
80+
float v0 = vec[i];
81+
float v1 = vec[i+1];
82+
vec[i] = v0 * fcr - v1 * fci;
83+
vec[i+1] = v0 * fci + v1 * fcr;
84+
}
85+
}
86+
}
87+
5488
TransformerSpec Transformer::loadSpecFromFile(const char* path, const unsigned int nSlices, FloatType weightsFloatType, FloatType bufferFloatType) {
5589
TransformerSpec spec;
5690
memset(&spec, 0, sizeof(TransformerSpec));
@@ -252,6 +286,12 @@ Transformer::Transformer(TransformerSpec* spec, uint8_t sliceIndex) {
252286
#endif
253287
x = (float*)NEW_BUFFER(spec->dim * sizeof(float));
254288
logits = (float*)NEW_BUFFER(spec->vocabSize * sizeof(float));
289+
290+
// TODO: cache should be for all architectures
291+
if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
292+
ropeCache = (float*)NEW_BUFFER(spec->vocabSize * spec->dim);
293+
initRope(ropeCache, spec);
294+
}
255295
}
256296
}
257297

@@ -270,6 +310,10 @@ Transformer::~Transformer() {
270310
#endif
271311
FREE_BUFFER(x);
272312
FREE_BUFFER(logits);
313+
314+
if (spec->archType == LLAMA2 || spec->archType == MIXTRAL) {
315+
FREE_BUFFER(ropeCache);
316+
}
273317
}
274318
}
275319

src/transformer.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ struct TransformerSpec {
8383
uint8_t nSlices;
8484
};
8585

86+
void initRope(float* cache, TransformerSpec* spec);
87+
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex);
88+
8689
class TransformerBlock {
8790
public:
8891
uint8_t sliceIndex;
@@ -186,6 +189,7 @@ class Transformer {
186189
int pos;
187190
float* x;
188191
float* logits;
192+
float* ropeCache;
189193

190194
~Transformer();
191195

0 commit comments

Comments
 (0)