@@ -21,7 +21,7 @@ typedef SSIZE_T ssize_t;
21
21
#define SOCKET_LAST_ERROR strerror (errno)
22
22
23
23
#define ACK 23571113
24
- #define ONE_MB 1048576
24
+ #define MAX_CHUNK_SIZE 4096
25
25
26
26
static inline bool isEagainError () {
27
27
#ifdef _WIN32
@@ -338,11 +338,11 @@ std::unique_ptr<NnNetwork> NnNetwork::connect(NnSize nSockets, char **hosts, NnS
338
338
return std::unique_ptr<NnNetwork>(new NnNetwork (nSockets, sockets));
339
339
}
340
340
341
- NnNetwork::NnNetwork (NnSize nSockets, int *sockets) {
341
+ NnNetwork::NnNetwork (NnSize nSockets, int *sockets)
342
+ : sentBytes(0 ), recvBytes(0 )
343
+ {
342
344
this ->nSockets = nSockets;
343
345
this ->sockets = sockets;
344
- this ->sentBytes .exchange (0 );
345
- this ->recvBytes .exchange (0 );
346
346
}
347
347
348
348
NnNetwork::~NnNetwork () {
@@ -362,25 +362,25 @@ void NnNetwork::setTurbo(bool enabled) {
362
362
363
363
// Sends `size` bytes starting at `data` over the socket at `socketIndex`.
//
// The full `size` is added to the sent-bytes counter up front, then the
// payload is handed to writeSocket() in pieces of at most MAX_CHUNK_SIZE
// bytes, in order. Nothing is sent when `size` is zero.
void NnNetwork::write(NnSize socketIndex, const void *data, size_t size) {
    assert(socketIndex >= 0 && socketIndex < nSockets);
    sentBytes.fetch_add(size);

    const int fd = sockets[socketIndex];
    char *cursor = (char *)data;
    size_t remaining = size;
    while (remaining > 0) {
        // The last piece may be shorter than MAX_CHUNK_SIZE.
        const size_t pieceSize = remaining < MAX_CHUNK_SIZE ? remaining : MAX_CHUNK_SIZE;
        writeSocket(fd, cursor, pieceSize);
        cursor += pieceSize;
        remaining -= pieceSize;
    }
}
375
375
376
376
void NnNetwork::read (NnSize socketIndex, void *data, size_t size) {
377
377
assert (socketIndex >= 0 && socketIndex < nSockets);
378
- recvBytes += size;
378
+ recvBytes. fetch_add ( size) ;
379
379
380
380
char *current = (char *)data;
381
381
int s = sockets[socketIndex];
382
- for (size_t chunk = 0 ; chunk < size; chunk += ONE_MB ) {
383
- size_t chunkSize = chunk + ONE_MB < size ? ONE_MB : size - chunk;
382
+ for (size_t chunk = 0 ; chunk < size; chunk += MAX_CHUNK_SIZE ) {
383
+ size_t chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
384
384
readSocket (s, current, chunkSize);
385
385
current += chunkSize;
386
386
}
@@ -399,7 +399,7 @@ void NnNetwork::readAck(NnSize socketIndex) {
399
399
bool NnNetwork::tryReadWithMaxAttempts (NnSize socketIndex, void *data, size_t size, unsigned long maxAttempts) {
400
400
assert (socketIndex >= 0 && socketIndex < nSockets);
401
401
if (tryReadSocket (sockets[socketIndex], data, size, maxAttempts)) {
402
- recvBytes += size;
402
+ recvBytes. fetch_add ( size) ;
403
403
return true ;
404
404
}
405
405
return false ;
@@ -420,7 +420,8 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
420
420
if (io->size > 0 ) {
421
421
isWriting = true ;
422
422
int socket = sockets[io->socketIndex ];
423
- ssize_t s = send (socket, (const char *)io->data , io->size , 0 );
423
+ ssize_t chunkSize = io->size > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : io->size ;
424
+ ssize_t s = send (socket, (const char *)io->data , chunkSize, 0 );
424
425
if (s < 0 ) {
425
426
if (isEagainError ()) {
426
427
continue ;
@@ -434,7 +435,7 @@ void NnNetwork::writeMany(NnSize n, NnSocketIo *ios) {
434
435
}
435
436
}
436
437
} while (isWriting);
437
- sentBytes += nBytes;
438
+ sentBytes. fetch_add ( nBytes) ;
438
439
}
439
440
440
441
void NnNetwork::writeAll (void *data, size_t size) {
@@ -477,18 +478,18 @@ void NnNetwork::readMany(NnSize n, NnSocketIo *ios) {
477
478
}
478
479
}
479
480
} while (isReading);
480
- recvBytes += nBytes;
481
+ recvBytes. fetch_add ( nBytes) ;
481
482
}
482
483
483
484
// Reports the traffic counters accumulated since the last reset and clears them.
//
// sentBytes: out-parameter that receives the sent-byte counter; must not be null.
// recvBytes: out-parameter that receives the received-byte counter; must not be null.
//
// Each counter is read and zeroed with a single atomic exchange. Reading with
// load() and resetting with a separate exchange(0) would lose any bytes another
// thread added via fetch_add() between the two steps.
// Note: the two counters are still sampled one after the other, not as a joint
// snapshot.
void NnNetwork::getStats(size_t *sentBytes, size_t *recvBytes) {
    // The parameters shadow the members, hence the explicit this->.
    *sentBytes = this->sentBytes.exchange(0);
    *recvBytes = this->recvBytes.exchange(0);
}
488
489
489
490
// Clears both traffic counters (sent and received bytes) back to zero.
void NnNetwork::resetStats() {
    // The previous values are not needed, so a plain atomic store suffices.
    this->sentBytes.store(0);
    this->recvBytes.store(0);
}
493
494
494
495
static void syncWithRoot (NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnSize nThreads, NnSize threadIndex) {
@@ -525,8 +526,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
525
526
if (nSocketsPerThread == 0 ) return ;
526
527
NnSize sliceBytes = nBytes / nNodes;
527
528
528
- std::unique_ptr<NnSocketIo> iosPtr (new NnSocketIo[nSocketsPerThread]);
529
- NnSocketIo *ios = iosPtr.get ();
529
+ std::vector<NnSocketIo> ios (nSocketsPerThread);
530
530
531
531
if (!onlyFromWorkerToRoot || isWorker) {
532
532
NnByte *mySliceData = &buffer[sliceBytes * nodeIndex];
@@ -537,7 +537,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
537
537
ios[i].data = mySliceData;
538
538
ios[i].size = sliceBytes;
539
539
}
540
- network->writeMany (nSocketsPerThread, ios);
540
+ network->writeMany (nSocketsPerThread, & ios[ 0 ] );
541
541
}
542
542
543
543
if (!onlyFromWorkerToRoot || !isWorker) {
@@ -549,7 +549,7 @@ static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnSize
549
549
ios[i].data = sliceData;
550
550
ios[i].size = sliceBytes;
551
551
}
552
- network->readMany (nSocketsPerThread, ios);
552
+ network->readMany (nSocketsPerThread, & ios[ 0 ] );
553
553
}
554
554
}
555
555
0 commit comments