@@ -44,23 +44,34 @@ size_t MatmulSlice::splitWeights(uint8_t sliceIndex, char* weights, char* weight
44
44
}
45
45
46
46
// Builds the rotary-embedding coefficient table for one slice.
//
// For every position in [0, spec->seqLen) the table stores interleaved
// (cos, sin) pairs for the global dimension range [kvDim0From, qDim0To),
// i.e. from the start of this slice's kv range to the end of its q range.
// Entry `cache[pos * cacheDim + (i - kvDim0From)]` holds the cosine for
// global dimension `i`, and the following entry holds the sine.
//
// spec       - model description; dim, kvDim, nSlices, seqLen, headSize and
//              ropeTheta are read here.
// sliceIndex - index of this slice; scales the q and kv starting offsets.
//
// Sets qDim0, kvDim0, qOffset, cacheDim and cache (presumably class members;
// their declarations are not visible in this part of the file).
RopeSlice::RopeSlice(TransformerSpec* spec, uint8_t sliceIndex) {
    assert(spec->dim >= spec->kvDim);
    assert(spec->dim % spec->nSlices == 0);
    assert(spec->kvDim % spec->nSlices == 0);

    qDim0 = spec->dim / spec->nSlices;
    kvDim0 = spec->kvDim / spec->nSlices;
    // Values are rotated in (even, odd) pairs, so both widths must be even.
    assert(qDim0 % 2 == 0);
    assert(kvDim0 % 2 == 0);

    int kvDim0From = kvDim0 * sliceIndex;
    int qDim0From = qDim0 * sliceIndex;
    int qDim0To = qDim0From + qDim0;
    // The table starts at kvDim0From, so the q range begins qOffset entries in.
    qOffset = qDim0From - kvDim0From;
    cacheDim = qDim0To - kvDim0From;
    assert(cacheDim % 2 == 0);

    // Widen before multiplying so the product is computed in size_t.
    size_t cacheBytes = (size_t)spec->seqLen * cacheDim * sizeof(float);
    cache = (float*)NEW_BUFFER(cacheBytes);
    // %zu is the matching conversion for size_t (%ld is not portable).
    printf("🕒 ropeCache: %zu bytes\n", cacheBytes);

    for (pos_t pos = 0; pos < spec->seqLen; pos++) {
        float* row = cache + (size_t)pos * cacheDim;
        for (int i = kvDim0From; i < qDim0To; i += 2) {
            // The rotation frequency depends on the position inside the head.
            int headDim = i % spec->headSize;
            float freq = 1.0f / powf(spec->ropeTheta, headDim / (float)spec->headSize);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);
            row[i - kvDim0From] = fcr;
            row[i - kvDim0From + 1] = fci;
        }
    }
}
@@ -69,25 +80,23 @@ RopeSlice::~RopeSlice() {
69
80
FREE_BUFFER (cache);
70
81
}
71
82
72
- void RopeSlice::forward (float * q, float * k, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
73
- int halfDim0 = dim0 / 2 ;
74
- int slice = halfDim0 / nThreads;
83
+ void RopeSlice::forward (bool isQ, float * qOrV, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
84
+ int d0 = isQ ? qDim0 : kvDim0;
85
+ int offset = isQ ? qOffset : 0 ;
86
+ int halfD0 = d0 / 2 ;
87
+ int slice = halfD0 / nThreads;
75
88
int iStart = threadIndex * slice;
76
- int iEnd = (nThreads - 1 == threadIndex) ? halfDim0 : (iStart + slice);
89
+ int iEnd = (nThreads - 1 == threadIndex) ? halfD0 : (iStart + slice);
77
90
iStart *= 2 ;
78
91
iEnd *= 2 ;
79
92
80
93
for (int i = iStart; i < iEnd; i += 2 ) {
81
- float fcr = cache[pos * dim0 + i];
82
- float fci = cache[pos * dim0 + i + 1 ];
83
- int rotn = (dimOffset + i) < kvDim ? 2 : 1 ; // how many vectors? 2 = q & k, 1 = q only
84
- for (int _v = 0 ; _v < rotn; _v++) {
85
- float * vec = _v == 0 ? q : k; // the vector to rotate (query or key)
86
- float v0 = vec[i];
87
- float v1 = vec[i+1 ];
88
- vec[i] = v0 * fcr - v1 * fci;
89
- vec[i+1 ] = v0 * fci + v1 * fcr;
90
- }
94
+ float fcr = cache[pos * cacheDim + offset + i];
95
+ float fci = cache[pos * cacheDim + offset + i + 1 ];
96
+ float v0 = qOrV[i];
97
+ float v1 = qOrV[i+1 ];
98
+ qOrV[i] = v0 * fcr - v1 * fci;
99
+ qOrV[i+1 ] = v0 * fci + v1 * fcr;
91
100
}
92
101
}
93
102
0 commit comments