Skip to content

Commit a4964a0

Browse files
authored
fix: fixed inference getting stuck (#166)
1 parent 1e73dcb commit a4964a0

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

src/app.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,10 @@ void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *cont
234234

235235
RootLlmInference inference(&net, &cpu, &execution, &executor, network);
236236

237-
if (network != nullptr)
237+
if (network != nullptr) {
238238
network->resetStats();
239+
network->setTurbo(true);
240+
}
239241

240242
AppInferenceContext context;
241243
context.args = args;

src/nn/nn-network.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ typedef SSIZE_T ssize_t;
2121
#define SOCKET_LAST_ERROR strerror(errno)
2222

2323
#define ACK 23571113
24-
#define ONE_MB 1048576
24+
#define MAX_CHUNK_SIZE 4096
2525

2626
static inline bool isEagainError() {
2727
#ifdef _WIN32
@@ -338,11 +338,11 @@ std::unique_ptr<NnNetwork> NnNetwork::connect(NnSize nSockets, char **hosts, NnS
338338
return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
339339
}
340340

341-
NnNetwork::NnNetwork(NnSize nSockets, int *sockets) {
341+
NnNetwork::NnNetwork(NnSize nSockets, int *sockets)
342+
: sentBytes(0), recvBytes(0)
343+
{
342344
this->nSockets = nSockets;
343345
this->sockets = sockets;
344-
this->sentBytes.exchange(0);
345-
this->recvBytes.exchange(0);
346346
}
347347

348348
NnNetwork::~NnNetwork() {
@@ -362,25 +362,25 @@ void NnNetwork::setTurbo(bool enabled) {
362362

363363
void NnNetwork::write(NnSize socketIndex, const void *data, size_t size) {
364364
assert(socketIndex >= 0 && socketIndex < nSockets);
365-
sentBytes += size;
365+
sentBytes.fetch_add(size);
366366

367367
char *current = (char*)data;
368368
int s = sockets[socketIndex];
369-
for (size_t chunk = 0; chunk < size; chunk += ONE_MB) {
370-
size_t chunkSize = chunk + ONE_MB < size ? ONE_MB : size - chunk;
369+
for (size_t chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
370+
size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
371371
writeSocket(s, current, chunkSize);
372372
current += chunkSize;
373373
}
374374
}
375375

376376
void NnNetwork::read(NnSize socketIndex, void *data, size_t size) {
377377
assert(socketIndex >= 0 && socketIndex < nSockets);
378-
recvBytes += size;
378+
recvBytes.fetch_add(size);
379379

380380
char *current = (char*)data;
381381
int s = sockets[socketIndex];
382-
for (size_t chunk = 0; chunk < size; chunk += ONE_MB) {
383-
size_t chunkSize = chunk + ONE_MB < size ? ONE_MB : size - chunk;
382+
for (size_t chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
383+
size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
384384
readSocket(s, current, chunkSize);
385385
current += chunkSize;
386386
}
@@ -399,7 +399,7 @@ void NnNetwork::readAck(NnSize socketIndex) {
399399
bool NnNetwork::tryReadWithMaxAttempts(NnSize socketIndex, void *data, size_t size, unsigned long maxAttempts) {
400400
assert(socketIndex >= 0 && socketIndex < nSockets);
401401
if (tryReadSocket(sockets[socketIndex], data, size, maxAttempts)) {
402-
recvBytes += size;
402+
recvBytes.fetch_add(size);
403403
return true;
404404
}
405405
return false;
@@ -420,7 +420,8 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
420420
if (io->size > 0) {
421421
isWriting = true;
422422
int socket = sockets[io->socketIndex];
423-
ssize_t s = send(socket, (const char*)io->data, io->size, 0);
423+
ssize_t chunkSize = io->size > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : io->size;
424+
ssize_t s = send(socket, (const char*)io->data, chunkSize, 0);
424425
if (s < 0) {
425426
if (isEagainError()) {
426427
continue;
@@ -434,7 +435,7 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
434435
}
435436
}
436437
} while (isWriting);
437-
sentBytes += nBytes;
438+
sentBytes.fetch_add(nBytes);
438439
}
439440

440441
void NnNetwork::writeAll(void *data, size_t size) {
@@ -477,18 +478,18 @@ void NnNetwork::readMany(NnSize n, NnSocketIo *ios) {
477478
}
478479
}
479480
} while (isReading);
480-
recvBytes += nBytes;
481+
recvBytes.fetch_add(nBytes);
481482
}
482483

483484
void NnNetwork::getStats(size_t *sentBytes, size_t *recvBytes) {
484-
*sentBytes = this->sentBytes;
485-
*recvBytes = this->recvBytes;
486-
this->resetStats();
485+
*sentBytes = this->sentBytes.load();
486+
*recvBytes = this->recvBytes.load();
487+
resetStats();
487488
}
488489

489490
void NnNetwork::resetStats() {
490-
this->sentBytes.exchange(0);
491-
this->recvBytes.exchange(0);
491+
sentBytes.exchange(0);
492+
recvBytes.exchange(0);
492493
}
493494

494495
static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnSize nThreads, NnSize threadIndex) {
@@ -525,8 +526,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
525526
if (nSocketsPerThread == 0) return;
526527
NnSize sliceBytes = nBytes / nNodes;
527528

528-
std::unique_ptr<NnSocketIo> iosPtr(new NnSocketIo[nSocketsPerThread]);
529-
NnSocketIo *ios = iosPtr.get();
529+
std::vector<NnSocketIo> ios(nSocketsPerThread);
530530

531531
if (!onlyFromWorkerToRoot || isWorker) {
532532
NnByte *mySliceData = &buffer[sliceBytes * nodeIndex];
@@ -537,7 +537,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
537537
ios[i].data = mySliceData;
538538
ios[i].size = sliceBytes;
539539
}
540-
network->writeMany(nSocketsPerThread, ios);
540+
network->writeMany(nSocketsPerThread, &ios[0]);
541541
}
542542

543543
if (!onlyFromWorkerToRoot || !isWorker) {
@@ -549,7 +549,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
549549
ios[i].data = sliceData;
550550
ios[i].size = sliceBytes;
551551
}
552-
network->readMany(nSocketsPerThread, ios);
552+
network->readMany(nSocketsPerThread, &ios[0]);
553553
}
554554
}
555555

0 commit comments

Comments
 (0)