Skip to content

Commit 1aea4df

Browse files
committed
eureka!
1 parent d89e532 commit 1aea4df

File tree

4 files changed

+41
-30
lines changed

4 files changed

+41
-30
lines changed

Diff for: src/llama2-tasks.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ void llamaRope(TASK_ARGS) {
4747
TASK_VARIABLES;
4848
float* q = (float*)transformer->buffer->getSliced(TB_SLICED_Q, transformer->sliceIndex);
4949
float* k = (float*)transformer->buffer->getSliced(TB_SLICED_K, transformer->sliceIndex);
50-
transformer->ropeSlice->forward(q, k, transformer->pos, nThreads, threadIndex);
50+
transformer->ropeSlice->forward(true, q, transformer->pos, nThreads, threadIndex);
51+
transformer->ropeSlice->forward(false, k, transformer->pos, nThreads, threadIndex);
5152
}
5253

5354
void llamaQuantizeQkv(TASK_ARGS) {

Diff for: src/transformer-test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ void testRopeSlice() {
1717
float* k = new float[spec.kvDim];
1818
float* correctQ = new float[spec.dim];
1919
float* correctK = new float[spec.kvDim];
20-
const int nSliceTests = 7;
21-
const int nPosTests = 8;
20+
const int nSliceTests = 5;
21+
const int nPosTests = 6;
2222
const int nThreadTests = 3;
2323

2424
for (int pos = 0; pos < spec.seqLen; pos += spec.seqLen / nPosTests) {

Diff for: src/transformer.cpp

+32-23
Original file line numberDiff line numberDiff line change
@@ -44,23 +44,34 @@ size_t MatmulSlice::splitWeights(uint8_t sliceIndex, char* weights, char* weight
4444
}
4545

4646
RopeSlice::RopeSlice(TransformerSpec* spec, uint8_t sliceIndex) {
47+
assert(spec->dim >= spec->kvDim);
4748
assert(spec->dim % spec->nSlices == 0);
48-
kvDim = spec->kvDim;
49-
dim0 = spec->dim / spec->nSlices;
50-
assert(dim0 % 2 == 0);
51-
dimOffset = dim0 * sliceIndex;
52-
size_t cacheBytes = spec->seqLen * dim0 * sizeof(float);
49+
assert(spec->kvDim % spec->nSlices == 0);
50+
51+
qDim0 = spec->dim / spec->nSlices;
52+
kvDim0 = spec->kvDim / spec->nSlices;
53+
assert(qDim0 % 2 == 0);
54+
assert(kvDim0 % 2 == 0);
55+
int kvDim0From = kvDim0 * sliceIndex;
56+
int qDim0From = qDim0 * sliceIndex;
57+
int qDim0To = qDim0From + qDim0;
58+
qOffset = qDim0From - kvDim0From;
59+
cacheDim = qDim0To - kvDim0From;
60+
assert(cacheDim % 2 == 0);
61+
62+
size_t cacheBytes = spec->seqLen * cacheDim * sizeof(float);
5363
cache = (float*)NEW_BUFFER(cacheBytes);
64+
printf("🕒 ropeCache: %ld bytes\n", cacheBytes);
5465

5566
for (pos_t pos = 0; pos < spec->seqLen; pos++) {
56-
for (int i = 0; i < dim0; i += 2) {
57-
int headDim = (i + dimOffset) % spec->headSize;
67+
for (int i = kvDim0From; i < qDim0To; i += 2) {
68+
int headDim = i % spec->headSize;
5869
float freq = 1.0f / powf(spec->ropeTheta, headDim / (float)spec->headSize);
5970
float val = pos * freq;
6071
float fcr = cosf(val);
6172
float fci = sinf(val);
62-
cache[pos * dim0 + i] = fcr;
63-
cache[pos * dim0 + i + 1] = fci;
73+
cache[pos * cacheDim + (i - kvDim0From)] = fcr;
74+
cache[pos * cacheDim + (i - kvDim0From) + 1] = fci;
6475
}
6576
}
6677
}
@@ -69,25 +80,23 @@ RopeSlice::~RopeSlice() {
6980
FREE_BUFFER(cache);
7081
}
7182

72-
void RopeSlice::forward(float* q, float* k, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
73-
int halfDim0 = dim0 / 2;
74-
int slice = halfDim0 / nThreads;
83+
void RopeSlice::forward(bool isQ, float* qOrV, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
84+
int d0 = isQ ? qDim0 : kvDim0;
85+
int offset = isQ ? qOffset : 0;
86+
int halfD0 = d0 / 2;
87+
int slice = halfD0 / nThreads;
7588
int iStart = threadIndex * slice;
76-
int iEnd = (nThreads - 1 == threadIndex) ? halfDim0 : (iStart + slice);
89+
int iEnd = (nThreads - 1 == threadIndex) ? halfD0 : (iStart + slice);
7790
iStart *= 2;
7891
iEnd *= 2;
7992

8093
for (int i = iStart; i < iEnd; i += 2) {
81-
float fcr = cache[pos * dim0 + i];
82-
float fci = cache[pos * dim0 + i + 1];
83-
int rotn = (dimOffset + i) < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
84-
for (int _v = 0; _v < rotn; _v++) {
85-
float* vec = _v == 0 ? q : k; // the vector to rotate (query or key)
86-
float v0 = vec[i];
87-
float v1 = vec[i+1];
88-
vec[i] = v0 * fcr - v1 * fci;
89-
vec[i+1] = v0 * fci + v1 * fcr;
90-
}
94+
float fcr = cache[pos * cacheDim + offset + i];
95+
float fci = cache[pos * cacheDim + offset + i + 1];
96+
float v0 = qOrV[i];
97+
float v1 = qOrV[i+1];
98+
qOrV[i] = v0 * fcr - v1 * fci;
99+
qOrV[i+1] = v0 * fci + v1 * fcr;
91100
}
92101
}
93102

Diff for: src/transformer.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,14 @@ struct TransformerSpec {
8787
class RopeSlice {
8888
private:
8989
float* cache;
90-
int kvDim;
90+
int cacheDim;
91+
int qDim0;
92+
int qOffset;
93+
int kvDim0;
9194
public:
92-
int dim0;
93-
int dimOffset;
9495
RopeSlice(TransformerSpec* spec, uint8_t sliceIndex);
9596
~RopeSlice();
96-
void forward(float* q, float* k, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
97+
void forward(bool isQ, float* qOrV, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
9798
};
9899

99100
class TransformerBlock {

0 commit comments

Comments
 (0)