Skip to content

Commit 8b1cf89

Browse files
authored
feat: mesh topology, distributed all layers. (#136)
1 parent 0d1121e commit 8b1cf89

24 files changed

+792
-938
lines changed

.github/workflows/main.yml

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@ jobs:
3333
make tokenizer-test
3434
make commands-test
3535
make llama2-tasks-test
36-
make grok1-tasks-test
3736
- name: funcs-test
3837
run: ./funcs-test
3938
- name: quants-test
@@ -44,8 +43,6 @@ jobs:
4443
run: ./commands-test
4544
- name: llama2-tasks-test
4645
run: ./llama2-tasks-test
47-
- name: grok1-tasks-test
48-
run: ./grok1-tasks-test
4946

5047
build-windows:
5148
name: Windows
@@ -66,7 +63,6 @@ jobs:
6663
make tokenizer-test
6764
make commands-test
6865
make llama2-tasks-test
69-
make grok1-tasks-test
7066
- name: funcs-test
7167
run: ./funcs-test
7268
- name: quants-test
@@ -77,5 +73,3 @@ jobs:
7773
run: ./commands-test
7874
- name: llama2-tasks-test
7975
run: ./llama2-tasks-test
80-
- name: grok1-tasks-test
81-
run: ./grok1-tasks-test

Makefile

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,19 +26,17 @@ tasks: src/tasks.cpp
2626
$(CXX) $(CXXFLAGS) -c src/tasks.cpp -o tasks.o
2727
llama2-tasks: src/llama2-tasks.cpp
2828
$(CXX) $(CXXFLAGS) -c src/llama2-tasks.cpp -o llama2-tasks.o
29-
grok1-tasks: src/grok1-tasks.cpp
30-
$(CXX) $(CXXFLAGS) -c src/grok1-tasks.cpp -o grok1-tasks.o
3129
mixtral-tasks: src/mixtral-tasks.cpp
3230
$(CXX) $(CXXFLAGS) -c src/mixtral-tasks.cpp -o mixtral-tasks.o
3331
tokenizer: src/tokenizer.cpp
3432
$(CXX) $(CXXFLAGS) -c src/tokenizer.cpp -o tokenizer.o
3533
app: src/app.cpp
3634
$(CXX) $(CXXFLAGS) -c src/app.cpp -o app.o
3735

38-
dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
39-
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
40-
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
41-
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
36+
dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks mixtral-tasks tokenizer app
37+
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
38+
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks mixtral-tasks tokenizer app
39+
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
4240
socket-benchmark: src/apps/socket-benchmark/socket-benchmark.cpp socket
4341
$(CXX) $(CXXFLAGS) src/apps/socket-benchmark/socket-benchmark.cpp -o socket-benchmark socket.o $(LIBS)
4442

@@ -52,5 +50,3 @@ commands-test: src/commands-test.cpp funcs commands utils quants transformer soc
5250
$(CXX) $(CXXFLAGS) src/commands-test.cpp -o commands-test funcs.o commands.o utils.o quants.o transformer.o socket.o $(LIBS)
5351
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks tokenizer
5452
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
55-
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks tokenizer
56-
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)

converter/requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
11
numpy==1.23.5
2-
torch==2.0.1
3-
safetensors==0.4.2
2+
pytorch==2.0.1
3+
safetensors==0.4.2
4+
sentencepiece==0.1.99

src/app.cpp

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -42,8 +42,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
4242
args.seed = (unsigned long long)time(NULL);
4343
args.chatTemplateType = TEMPLATE_UNKNOWN;
4444
args.maxSeqLen = 0;
45-
args.useDiscForKvCache = false;
46-
45+
args.packetAlignment = 0;
4746
int i = 1;
4847
if (hasMode && argc > 1) {
4948
args.mode = argv[1];
@@ -102,8 +101,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
102101
args.chatTemplateType = parseChatTemplateType(value);
103102
} else if (strcmp(name, "--max-seq-len") == 0) {
104103
args.maxSeqLen = (unsigned int)atoi(value);
105-
} else if (strcmp(name, "--kv-cache-storage") == 0) {
106-
args.useDiscForKvCache = strcmp(value, "disc") == 0;
104+
} else if (strcmp(name, "--packet-alignment") == 0) {
105+
args.packetAlignment = (size_t)atoi(value);
107106
} else {
108107
printf("Unknown option %s\n", name);
109108
exit(EXIT_FAILURE);
@@ -114,8 +113,6 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
114113

115114
TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
116115
if (spec->archType == LLAMA) return buildLlamaArch(spec);
117-
if (spec->archType == GROK1) return buildGrok1Arch(spec);
118-
if (spec->archType == MIXTRAL) return buildMixtralArch(spec);
119116
printf("Unsupported arch type: %d\n", spec->archType);
120117
exit(EXIT_FAILURE);
121118
}
@@ -128,7 +125,7 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
128125
throw std::runtime_error("Tokenizer is required");
129126
}
130127

131-
SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts);
128+
SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts, args->packetAlignment);
132129
unsigned int nSlices = args->nWorkers + 1;
133130

134131
TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->maxSeqLen, args->weightsFloatType, args->bufferFloatType);
@@ -140,7 +137,6 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
140137
}
141138

142139
TransformerConfig config;
143-
config.useDiscForKvCache = args->useDiscForKvCache;
144140

145141
Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool);
146142
socketPool->setTurbo(true);

src/app.hpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,14 @@
99
#include "transformer.hpp"
1010
#include "tasks.hpp"
1111
#include "llama2-tasks.hpp"
12-
#include "grok1-tasks.hpp"
1312
#include "mixtral-tasks.hpp"
1413
#include "tokenizer.hpp"
1514

1615
class AppArgs {
1716
public:
1817
char* mode;
1918
int nThreads;
20-
bool useDiscForKvCache;
19+
size_t packetAlignment;
2120

2221
// inference
2322
char* modelPath;

src/apps/dllama-api/dllama-api.cpp

Lines changed: 49 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -37,10 +37,10 @@ enum class HttpMethod {
3737

3838
class HttpRequest {
3939
public:
40-
static HttpRequest read(Socket& socket) {
41-
HttpRequest req(&socket);
40+
static HttpRequest read(int serverSocket) {
41+
HttpRequest req(serverSocket);
4242

43-
std::vector<char> httpRequest = socket.readHttpRequest();
43+
std::vector<char> httpRequest = req.readHttpRequest();
4444
// Parse the HTTP request
4545
std::string data = std::string(httpRequest.begin(), httpRequest.end());
4646

@@ -89,16 +89,48 @@ class HttpRequest {
8989
}
9090

9191
private:
92-
Socket* socket;
92+
int serverSocket;
9393
public:
9494
std::string path;
9595
std::unordered_map<std::string, std::string> headers;
9696
std::string body;
9797
json parsedJson;
9898
HttpMethod method;
9999

100-
HttpRequest(Socket* socket) {
101-
this->socket = socket;
100+
HttpRequest(int serverSocket) {
101+
this->serverSocket = serverSocket;
102+
}
103+
104+
std::vector<char> readHttpRequest() {
105+
std::vector<char> httpRequest;
106+
char buffer[1024 * 1024]; // TODO: this should be refactored asap
107+
ssize_t bytesRead;
108+
109+
// Peek into the socket buffer to check available data
110+
bytesRead = recv(serverSocket, buffer, sizeof(buffer), MSG_PEEK);
111+
if (bytesRead <= 0) {
112+
// No data available or error occurred
113+
if (bytesRead == 0) {
114+
// No more data to read
115+
return httpRequest;
116+
} else {
117+
// Error while peeking
118+
throw std::runtime_error("Error while peeking into socket");
119+
}
120+
}
121+
122+
// Resize buffer according to the amount of data available
123+
std::vector<char> peekBuffer(bytesRead);
124+
bytesRead = recv(serverSocket, peekBuffer.data(), bytesRead, 0);
125+
if (bytesRead <= 0) {
126+
// Error while reading
127+
throw std::runtime_error("Error while reading from socket");
128+
}
129+
130+
// Append data to httpRequest
131+
httpRequest.insert(httpRequest.end(), peekBuffer.begin(), peekBuffer.end());
132+
133+
return httpRequest;
102134
}
103135

104136
std::string getMethod() {
@@ -111,7 +143,7 @@ class HttpRequest {
111143

112144
void writeNotFound() {
113145
const char* data = "HTTP/1.1 404 Not Found\r\n";
114-
socket->write(data, strlen(data));
146+
writeSocket(serverSocket, data, strlen(data));
115147
}
116148

117149
void writeJson(std::string json) {
@@ -120,7 +152,7 @@ class HttpRequest {
120152
<< "Content-Type: application/json; charset=utf-8\r\n"
121153
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
122154
std::string data = buffer.str();
123-
socket->write(data.c_str(), data.size());
155+
writeSocket(serverSocket, data.c_str(), data.size());
124156
}
125157

126158
void writeStreamStartChunk() {
@@ -130,19 +162,19 @@ class HttpRequest {
130162
<< "Connection: close\r\n"
131163
<< "Transfer-Encoding: chunked\r\n\r\n";
132164
std::string data = buffer.str();
133-
socket->write(data.c_str(), data.size());
165+
writeSocket(serverSocket, data.c_str(), data.size());
134166
}
135167

136168
void writeStreamChunk(const std::string data) {
137169
std::ostringstream buffer;
138170
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
139171
std::string d = buffer.str();
140-
socket->write(d.c_str(), d.size());
172+
writeSocket(serverSocket, d.c_str(), d.size());
141173
}
142174

143175
void writeStreamEndChunk() {
144176
const char* endChunk = "0000\r\n\r\n";
145-
socket->write(endChunk, strlen(endChunk));
177+
writeSocket(serverSocket, endChunk, strlen(endChunk));
146178
}
147179
};
148180

@@ -260,9 +292,6 @@ class ApiServer {
260292
std::vector<ChatMessage> deltaPrompt = params.messages;
261293
naiveCache.resolveDeltaPrompt(deltaPrompt, startPos);
262294

263-
printf("🔸");
264-
fflush(stdout);
265-
266295
size_t nInputItems = deltaPrompt.size();
267296
ChatItem inputItems[nInputItems];
268297
for (size_t i = 0; i < nInputItems; i++) {
@@ -271,6 +300,8 @@ class ApiServer {
271300
}
272301

273302
std::string inputPrompt = chatTemplate->generate(nInputItems, inputItems, true);
303+
printf("🔹%s🔸", inputPrompt.c_str());
304+
274305
int promptLength = inputPrompt.size();
275306
int nPromptTokens;
276307
int promptTokens[promptLength + 3];
@@ -393,7 +424,7 @@ void handleModelsRequest(HttpRequest& request) {
393424
}
394425

395426
void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
396-
SocketServer* server = new SocketServer(args->port);
427+
int serverSocket = createServerSocket(args->port);
397428

398429
TokenizerChatStops stops(tokenizer);
399430
ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
@@ -417,8 +448,8 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
417448

418449
while (true) {
419450
try {
420-
Socket client = server->accept();
421-
HttpRequest request = HttpRequest::read(client);
451+
int clientSocket = acceptSocket(serverSocket);
452+
HttpRequest request = HttpRequest::read(clientSocket);
422453
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
423454
Router::resolve(request, routes);
424455
} catch (ReadSocketException& ex) {
@@ -428,7 +459,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
428459
}
429460
}
430461

431-
delete server;
462+
closeServerSocket(serverSocket);
432463
}
433464

434465
int main(int argc, char *argv[]) {

src/apps/dllama/dllama.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -208,16 +208,16 @@ void worker(AppArgs* args) {
208208
}
209209

210210
TransformerConfig config;
211-
config.useDiscForKvCache = args->useDiscForKvCache;
212211

213-
SocketServer server(args->port);
214-
Socket socket = server.accept();
212+
SocketPool* socketPool = SocketPool::serve(args->port);
215213
TransformerSpec spec;
216-
Transformer transformer = Transformer::loadSlice(&spec, &config, &socket);
214+
Transformer transformer = Transformer::loadSlice(&spec, &config, socketPool);
217215
TransformerArch arch = TransformerArchFactory::create(&spec);
218216

219-
Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
217+
Worker worker = Worker(&arch, args->nThreads, &transformer, socketPool);
220218
worker.work();
219+
220+
delete socketPool;
221221
}
222222

223223
int main(int argc, char *argv[]) {

src/apps/socket-benchmark/socket-benchmark.cpp

Lines changed: 8 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
#include <sys/socket.h>
66
#include <arpa/inet.h>
77
#include <fcntl.h>
8+
#include <stdexcept>
9+
#include <cassert>
810

911
using namespace std::chrono;
1012

@@ -15,14 +17,6 @@ unsigned int nAttempts = 5000;
1517
int port = 7721;
1618
bool testTcp = true;
1719

18-
void setNonBlocking(int socket) {
19-
//int flags = fcntl(socket, F_GETFL, 0);
20-
//if (fcntl(socket, F_SETFL, flags |= O_NONBLOCK) < 0)
21-
// throw std::runtime_error("Cannot set socket flags");
22-
}
23-
24-
#define MAX_PACKAGE_SIZE 1280
25-
2620
char pktinfo[4096] = {0};
2721

2822
void readUdpSocket(int socket, char* buffer, unsigned int size, struct sockaddr_in* clientAddr, socklen_t* clientAddrLen) {
@@ -88,8 +82,9 @@ void server() {
8882
if (testTcp) {
8983
printf("TCP test\n");
9084

91-
SocketServer server(port);
92-
Socket socket = server.accept();
85+
SocketPool* pool = SocketPool::serve(port);
86+
assert(pool->nSockets == 1);
87+
9388
for (long i = 0; i < packageSizesCount; i++) {
9489
unsigned int currentPackageSize = packageSizes[i];
9590

@@ -98,9 +93,9 @@ void server() {
9893
long long totalTime = 0; // [us]
9994
for (long a = 0; a < nAttempts; a++) {
10095
auto t0 = high_resolution_clock::now();
101-
socket.read(buffer, currentPackageSize);
96+
pool->read(0, buffer, currentPackageSize);
10297
auto t1 = high_resolution_clock::now();
103-
socket.write(buffer, currentPackageSize);
98+
pool->write(0, buffer, currentPackageSize);
10499
auto t2 = high_resolution_clock::now();
105100

106101
totalReadTime += duration_cast<microseconds>(t1 - t0).count();
@@ -127,7 +122,6 @@ void server() {
127122
serverAddr.sin_family = AF_INET;
128123
serverAddr.sin_addr.s_addr = INADDR_ANY;
129124
serverAddr.sin_port = htons(port);
130-
setNonBlocking(serverSocket);
131125

132126
if (bind(serverSocket, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) < 0)
133127
throw std::runtime_error("Cannot bind socket");
@@ -176,7 +170,7 @@ void client(char* host) {
176170
int* ports = new int[1];
177171
ports[0] = port;
178172

179-
SocketPool* pool = SocketPool::connect(1, hosts, ports);
173+
SocketPool* pool = SocketPool::connect(1, hosts, ports, 0);
180174
pool->setTurbo(true);
181175

182176
for (long i = 0; i < packageSizesCount; i++) {
@@ -216,7 +210,6 @@ void client(char* host) {
216210
serverAddr.sin_family = AF_INET;
217211
serverAddr.sin_port = htons(port);
218212
serverAddr.sin_addr.s_addr = inet_addr(host);
219-
setNonBlocking(clientSocket);
220213

221214
for (long i = 0; i < packageSizesCount; i++) {
222215
unsigned int currentPackageSize = packageSizes[i];

src/commands.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,8 @@ class MatmulCommand {
8282
unsigned int n;
8383
unsigned int d;
8484
size_t cpuSize;
85-
void* cpuWeights;
8685
public:
86+
void* cpuWeights;
8787
MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType);
8888
~MatmulCommand();
8989
size_t loadWeights(const void* source);

0 commit comments

Comments
 (0)