Skip to content

Commit 3353d56

Browse files
authored
feat: reduction of writeMany/readMany calls. (#118)
1 parent 668ea98 commit 3353d56

File tree

2 files changed

25 additions and 52 deletions

2 files changed

25 additions and 52 deletions

src/tasks.cpp

Lines changed: 25 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -49,13 +49,15 @@ void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, Transformer
4949
// root
5050

5151
unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
52-
SocketIo ios[nSockets];
53-
for (int i = 0; i < nSockets; i++) {
54-
ios[i].socketIndex = threadIndex + i * nThreads;
55-
ios[i].data = buffer;
56-
ios[i].size = bufferBytes;
52+
if (nSockets > 0) {
53+
SocketIo ios[nSockets];
54+
for (int i = 0; i < nSockets; i++) {
55+
ios[i].socketIndex = threadIndex + i * nThreads;
56+
ios[i].data = buffer;
57+
ios[i].size = bufferBytes;
58+
}
59+
ctx->socketPool->writeMany(nSockets, ios);
5760
}
58-
ctx->socketPool->writeMany(nSockets, ios);
5961
} else if (ctx->socket != NULL) {
6062
if (threadIndex != 0) return;
6163

@@ -70,54 +72,24 @@ void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, Tr
7072
// root
7173

7274
unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
73-
SocketIo ios[nSockets];
74-
for (int i = 0; i < nSockets; i++) {
75-
int socketIndex = threadIndex + i * nThreads;
76-
uint8_t workerSliceIndex = socketIndex + 1;
77-
ios[i].socketIndex = socketIndex;
78-
ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex);
79-
ios[i].size = bufferBytes;
80-
}
81-
82-
ctx->socketPool->readMany(nSockets, ios);
83-
} else if (ctx->socket != NULL) {
84-
if (threadIndex != 0) return;
85-
86-
// worker
87-
void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);
88-
ctx->socket->write(buffer, bufferBytes);
89-
}
90-
}
91-
92-
void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex) {
93-
size_t sliceBytes = ctx->transformer->buffer->getSlicedBytes(bufferIndex);
94-
if (ctx->socketPool != NULL) {
95-
// root
96-
97-
unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
98-
SocketIo ios[nSockets];
99-
100-
for (uint8_t si = 0; si < ctx->transformer->spec->nSlices - 1; si++) {
101-
for (unsigned int i = 0; i < nSockets; i++) {
75+
if (nSockets > 0) {
76+
SocketIo ios[nSockets];
77+
for (int i = 0; i < nSockets; i++) {
10278
int socketIndex = threadIndex + i * nThreads;
10379
uint8_t workerSliceIndex = socketIndex + 1;
104-
slice_index_t sliceIndex = si < workerSliceIndex ? si : si + 1;
10580
ios[i].socketIndex = socketIndex;
106-
ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
107-
ios[i].size = sliceBytes;
81+
ios[i].data = ctx->transformer->buffer->getSliced(bufferIndex, workerSliceIndex);
82+
ios[i].size = bufferBytes;
10883
}
109-
ctx->socketPool->writeMany(nSockets, ios);
84+
85+
ctx->socketPool->readMany(nSockets, ios);
11086
}
11187
} else if (ctx->socket != NULL) {
11288
if (threadIndex != 0) return;
11389

11490
// worker
115-
for (slice_index_t sliceIndex = 0; sliceIndex < ctx->transformer->spec->nSlices; sliceIndex++) {
116-
if (sliceIndex != ctx->transformer->sliceIndex) {
117-
void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, sliceIndex);
118-
ctx->socket->read(buffer, sliceBytes);
119-
}
120-
}
91+
void* buffer = ctx->transformer->buffer->getSliced(bufferIndex, ctx->transformer->sliceIndex);
92+
ctx->socket->write(buffer, bufferBytes);
12193
}
12294
}
12395

@@ -167,13 +139,15 @@ void sendPos(TASK_ARGS) {
167139

168140
if (ctx->socketPool != NULL) {
169141
unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
170-
SocketIo ios[nSockets];
171-
for (int i = 0; i < nSockets; i++) {
172-
ios[i].socketIndex = threadIndex + i * nThreads;
173-
ios[i].data = &transformer->pos;
174-
ios[i].size = sizeof(pos_t);
142+
if (nSockets > 0) {
143+
SocketIo ios[nSockets];
144+
for (int i = 0; i < nSockets; i++) {
145+
ios[i].socketIndex = threadIndex + i * nThreads;
146+
ios[i].data = &transformer->pos;
147+
ios[i].size = sizeof(pos_t);
148+
}
149+
ctx->socketPool->writeMany(nSockets, ios);
175150
}
176-
ctx->socketPool->writeMany(nSockets, ios);
177151
}
178152
}
179153

src/tasks.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ class TransformerArch {
4444

4545
void syncUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
4646
void syncSliceOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
47-
void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t bufferIndex);
4847
void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
4948
void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
5049
void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);

0 commit comments

Comments (0)