Skip to content

Commit d5e8f89

Browse files
authored
sync pos. (#37)
1 parent b2f3450 commit d5e8f89

10 files changed

+35
-46
lines changed

src/grok1-tasks.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void grokMulInput(TASK_ARGS) {
1414
}
1515

1616
// source: https://github.com/karpathy/llama2.c/pull/408
17-
void ropeFalcon(float* q, float* k, TransformerSpec* spec, int pos, float theta) {
17+
void ropeFalcon(float* q, float* k, TransformerSpec* spec, pos_t pos, float theta) {
1818
for (int i = 0; i < spec->nHeads; i++) {
1919
for (int j = 0; j < spec->headSize / 2; j++) {
2020
float freq = 1.0f / powf(theta, 2.0f * (float)j / (float)spec->headSize);
@@ -301,7 +301,7 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {
301301

302302
// inference
303303

304-
a.I(sendPoke, TASK_TYPE_TRANSFER);
304+
a.I(sendPos, TASK_TYPE_TRANSFER);
305305
a.I(grokMulInput, TASK_TYPE_INFERENCE);
306306
for (int i = 0; i < spec->nLayers; i++) {
307307
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);

src/llama2-tasks.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {
286286

287287
// inference
288288

289-
a.I(sendPoke, TASK_TYPE_TRANSFER);
289+
a.I(sendPos, TASK_TYPE_TRANSFER);
290290
for (int i = 0; i < spec->nLayers; i++) {
291291
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);
292292
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);

src/main.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer
6666
long start = 0; // used to time our code, only initialized after first iteration
6767
int next; // will store the next token in the sequence
6868
int token = promptTokens[0]; // kick off with the first token in the prompt
69-
int pos = 0; // position in the sequence
69+
pos_t pos = 0; // position in the sequence
7070

7171
unsigned long inferenceTime;
7272
unsigned long transferTime;
@@ -139,7 +139,7 @@ void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sa
139139
int next; // will store the next token in the sequence
140140
int token; // stores the current token to feed into the transformer
141141
int prev_token;
142-
int pos = 0; // position in the sequence
142+
pos_t pos = 0; // position in the sequence
143143
while (pos < args->steps) {
144144
// when it is the user's turn to contribute tokens to the dialog...
145145
if (userTurn) {
@@ -236,9 +236,9 @@ void simpleServer(Inference* inference, SocketPool* socketPool, Tokenizer *token
236236
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);
237237

238238
int token = promptTokens[0];
239-
int maxPos = nPromptTokens + maxTokens;
239+
pos_t maxPos = nPromptTokens + maxTokens;
240240
if (maxPos > spec->seqLen) maxPos = spec->seqLen;
241-
for (int pos = 0; pos < maxPos; pos++) {
241+
for (pos_t pos = 0; pos < maxPos; pos++) {
242242
float* logits = inference->infer(token, pos);
243243

244244
if (pos < nPromptTokens - 1) {

src/mixtral-tasks.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {
77

88
// inference
99

10-
a.I(sendPoke, TASK_TYPE_TRANSFER);
10+
a.I(sendPos, TASK_TYPE_TRANSFER);
1111
for (int i = 0; i < spec->nLayers; i++) {
1212
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);
1313
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);

src/socket.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static inline void setNoDelay(int socket) {
3434
throw std::runtime_error("Error setting socket to no-delay");
3535
}
3636

37-
static inline void writeSocket(int socket, const char* data, size_t size) {
37+
static inline void writeSocket(int socket, const void* data, size_t size) {
3838
while (size > 0) {
3939
int s = send(socket, (char*)data, size, 0);
4040
if (s < 0) {
@@ -50,7 +50,7 @@ static inline void writeSocket(int socket, const char* data, size_t size) {
5050
}
5151
}
5252

53-
static inline void readSocket(bool* isNonBlocking, int socket, char* data, size_t size) {
53+
static inline void readSocket(bool* isNonBlocking, int socket, void* data, size_t size) {
5454
unsigned int attempt = 0;
5555
time_t startTime;
5656
while (size > 0) {
@@ -136,13 +136,13 @@ SocketPool::~SocketPool() {
136136
delete[] isNonBlocking;
137137
}
138138

139-
void SocketPool::write(unsigned int socketIndex, const char* data, size_t size) {
139+
void SocketPool::write(unsigned int socketIndex, const void* data, size_t size) {
140140
assert(socketIndex >= 0 && socketIndex < nSockets);
141141
sentBytes += size;
142142
writeSocket(sockets[socketIndex], data, size);
143143
}
144144

145-
void SocketPool::read(unsigned int socketIndex, char* data, size_t size) {
145+
void SocketPool::read(unsigned int socketIndex, void* data, size_t size) {
146146
assert(socketIndex >= 0 && socketIndex < nSockets);
147147
recvBytes += size;
148148
readSocket(&isNonBlocking[socketIndex], sockets[socketIndex], data, size);
@@ -236,11 +236,11 @@ Socket::~Socket() {
236236
close(socket);
237237
}
238238

239-
void Socket::write(const char* data, size_t size) {
239+
void Socket::write(const void* data, size_t size) {
240240
writeSocket(socket, data, size);
241241
}
242242

243-
void Socket::read(char* data, size_t size) {
243+
void Socket::read(void* data, size_t size) {
244244
readSocket(&isNonBlocking, socket, data, size);
245245
}
246246

src/socket.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class WriteSocketException : public std::exception {
2020

2121
struct SocketIo {
2222
unsigned int socketIndex;
23-
const char* data;
23+
const void* data;
2424
size_t size;
2525
};
2626

@@ -39,8 +39,8 @@ class SocketPool {
3939
SocketPool(unsigned int nSockets, int* sockets);
4040
~SocketPool();
4141

42-
void write(unsigned int socketIndex, const char* data, size_t size);
43-
void read(unsigned int socketIndex, char* data, size_t size);
42+
void write(unsigned int socketIndex, const void* data, size_t size);
43+
void read(unsigned int socketIndex, void* data, size_t size);
4444
void writeMany(unsigned int n, SocketIo* ios);
4545
void readMany(unsigned int n, SocketIo* ios);
4646
void getStats(size_t* sentBytes, size_t* recvBytes);
@@ -55,8 +55,8 @@ class Socket {
5555
Socket(int socket);
5656
~Socket();
5757

58-
void write(const char* data, size_t size);
59-
void read(char* data, size_t size);
58+
void write(const void* data, size_t size);
59+
void read(void* data, size_t size);
6060
};
6161

6262
class SocketServer {

src/tasks.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -160,27 +160,23 @@ void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, Tra
160160
}
161161
}
162162

163-
void sendPoke(TASK_ARGS) {
163+
void sendPos(TASK_ARGS) {
164164
TASK_VARIABLES;
165165

166166
if (ctx->socketPool != NULL) {
167-
const char poke = 0x25;
168-
169167
unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
170168
SocketIo ios[nSockets];
171169
for (int i = 0; i < nSockets; i++) {
172170
ios[i].socketIndex = threadIndex + i * nThreads;
173-
ios[i].data = &poke;
174-
ios[i].size = sizeof(char);
171+
ios[i].data = &transformer->pos;
172+
ios[i].size = sizeof(pos_t);
175173
}
176174
ctx->socketPool->writeMany(nSockets, ios);
177175
}
178176
}
179177

180-
void waitForPoke(Socket* socket) {
181-
char poke;
182-
socket->read(&poke, sizeof(char));
183-
assert(poke == 0x25);
178+
void waitForPos(Transformer* transformer, Socket* socket) {
179+
socket->read(&transformer->pos, sizeof(pos_t));
184180
}
185181

186182
Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool) {
@@ -190,15 +186,15 @@ Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer*
190186
context.transformer = transformer;
191187
context.socket = NULL;
192188
context.socketPool = socketPool;
193-
assert(arch->inference.tasks[0].handler == sendPoke);
189+
assert(arch->inference.tasks[0].handler == sendPos);
194190
taskLoop = new TaskLoop(nThreads, arch->inference.nTasks, TASK_N_TYPES, arch->inference.tasks, (void*)&context);
195191
}
196192

197193
Inference::~Inference() {
198194
delete taskLoop;
199195
}
200196

201-
float* Inference::infer(int token, int pos) {
197+
float* Inference::infer(int token, pos_t pos) {
202198
transformer->pos = pos;
203199

204200
float* contentRow = ((float*)transformer->tokenEmbeddingTable) + token * transformer->spec->dim;
@@ -231,7 +227,7 @@ Worker::~Worker() {
231227

232228
void Worker::work() {
233229
while (true) {
234-
waitForPoke(socket);
230+
waitForPos(transformer, socket);
235231

236232
context.currentBlockIndex = 0;
237233
taskLoop->run();

src/tasks.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadI
4848
void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
4949
void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
5050
void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
51-
void sendPoke(TASK_ARGS);
51+
void sendPos(TASK_ARGS);
5252

5353
class Inference {
5454
private:
@@ -60,7 +60,7 @@ class Inference {
6060
public:
6161
Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool);
6262
~Inference();
63-
float* infer(int token, int pos);
63+
float* infer(int token, pos_t pos);
6464
void getStats(unsigned long* inferenceTime, unsigned long* transferTime);
6565
};
6666

src/transformer.cpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,8 @@ size_t MatmulSlice::splitWeights(uint8_t sliceIndex, char* weights, char* weight
4343
return copiedBytes;
4444
}
4545

46-
long MatmulSlice::mergeOutputs(uint8_t sliceIndex, float* output, float* output0) {
47-
long offset = this->d0 * sliceIndex;
48-
for (int i = 0; i < this->d0; i++) {
49-
output[offset + i] = output0[i];
50-
}
51-
return offset; // offset in floats
52-
}
53-
5446
void initRope(float* cache, TransformerSpec* spec) {
55-
for (int pos = 0; pos < spec->seqLen; pos++) {
47+
for (pos_t pos = 0; pos < spec->seqLen; pos++) {
5648
for (int i = 0; i < spec->dim; i += 2) {
5749
int head_dim = i % spec->headSize;
5850
float freq = 1.0f / powf(spec->ropeTheta, head_dim / (float)spec->headSize);
@@ -65,7 +57,7 @@ void initRope(float* cache, TransformerSpec* spec) {
6557
}
6658
}
6759

68-
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex) {
60+
void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
6961
int halfDim = spec->dim / 2;
7062
int slice = halfDim / nThreads;
7163
int iStart = threadIndex * slice;

src/transformer.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include "quants.hpp"
77
#include "socket.hpp"
88

9+
typedef unsigned short pos_t;
10+
911
class MatmulSlice {
1012
public:
1113
FloatType type;
@@ -17,7 +19,6 @@ class MatmulSlice {
1719

1820
MatmulSlice(FloatType type, int nSlices, int n, int d);
1921
size_t splitWeights(uint8_t sliceIndex, char* weights, char* weights0);
20-
long mergeOutputs(uint8_t sliceIndex, float* output, float* output0);
2122
};
2223

2324
enum TransformerHeaderKey {
@@ -84,7 +85,7 @@ struct TransformerSpec {
8485
};
8586

8687
void initRope(float* cache, TransformerSpec* spec);
87-
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex);
88+
void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex);
8889

8990
class TransformerBlock {
9091
public:
@@ -185,8 +186,8 @@ class Transformer {
185186
size_t wclsBytes;
186187
char* wcls;
187188

189+
pos_t pos;
188190
float rms;
189-
int pos;
190191
float* x;
191192
float* logits;
192193
float* ropeCache;

0 commit comments

Comments
 (0)