3
3
#include < type_traits>
4
4
#include < mpi.h>
5
5
6
- #include " libdistributed_stop_token .h"
6
+ #include " libdistributed_task_manager .h"
7
7
8
8
namespace distributed {
9
9
namespace queue {
@@ -41,16 +41,21 @@ constexpr int ROOT = 0;
41
41
enum class worker_status : int {
42
42
done = 1 ,
43
43
more = 2 ,
44
- cancel = 3
44
+ cancel = 3 ,
45
+ new_task = 4
45
46
};
46
47
47
48
48
- class WorkerStopToken : public StopToken
49
+ template <class RequestType , class ResponseType >
50
+ class WorkerTaskManager : public TaskManager <RequestType>
49
51
{
50
52
public:
51
- WorkerStopToken (MPI_Comm comm, MPI_Request request)
52
- : comm(comm)
53
+ WorkerTaskManager (MPI_Comm comm, MPI_Request request, MPI_Datatype request_dtype, MPI_Datatype response_dtype)
54
+ : TaskManager<RequestType>()
55
+ , comm(comm)
53
56
, stop_request(request)
57
+ , request_type(request_dtype)
58
+ , response_type(response_dtype)
54
59
, flag(0 )
55
60
{}
56
61
@@ -70,20 +75,42 @@ class WorkerStopToken : public StopToken
70
75
MPI_Wait (&request, MPI_STATUS_IGNORE);
71
76
}
72
77
78
+ void push (RequestType const & request) override {
79
+ ResponseType response;
80
+ MPI_Request mpi_request;
81
+ // let master know a new task is coming
82
+ MPI_Isend (&response, 1 , response_type, 0 , (int )worker_status::new_task, comm, &mpi_request);
83
+ MPI_Wait (&mpi_request, MPI_STATUS_IGNORE);
84
+
85
+ // send the new request to the master
86
+ MPI_Isend (&request, 1 , request_type, 0 , (int )worker_status::new_task, comm, &mpi_request);
87
+ MPI_Wait (&mpi_request, MPI_STATUS_IGNORE);
88
+
89
+ }
90
+
73
91
private:
74
92
MPI_Comm comm;
75
93
MPI_Request stop_request;
94
+ MPI_Datatype request_type;
95
+ MPI_Datatype response_type;
76
96
int flag;
77
97
};
78
98
79
- class MasterStopToken : public StopToken
99
+ template <class RequestType >
100
+ class MasterTaskManager : public TaskManager <RequestType>
80
101
{
81
102
public:
82
- MasterStopToken (MPI_Comm comm)
83
- : StopToken(),
103
+ template <class TaskIt >
104
+ MasterTaskManager (MPI_Comm comm, TaskIt begin, TaskIt end)
105
+ : TaskManager<RequestType>(),
84
106
comm (comm),
85
107
is_stop_requested (0 )
86
- {}
108
+ {
109
+ while (begin != end) {
110
+ requests.emplace (*begin);
111
+ ++begin;
112
+ }
113
+ }
87
114
88
115
bool stop_requested () override {
89
116
return is_stop_requested == 1 ;
@@ -96,9 +123,30 @@ class MasterStopToken : public StopToken
96
123
MPI_Wait (&request, MPI_STATUS_IGNORE);
97
124
}
98
125
126
+ void push (RequestType const & request) override {
127
+ requests.emplace (request);
128
+ }
129
+
130
+ void pop () {
131
+ requests.pop ();
132
+ }
133
+
134
+ RequestType const & front () const {
135
+ return requests.front ();
136
+ }
137
+
138
+ RequestType& front () {
139
+ return requests.front ();
140
+ }
141
+
142
+ bool empty () const {
143
+ return requests.empty ();
144
+ }
145
+
99
146
private:
100
147
MPI_Comm comm;
101
148
int is_stop_requested;
149
+ std::queue<RequestType> requests;
102
150
};
103
151
104
152
template <class RequestType , class ResponseType , class TaskForwardIt , class Function >
@@ -112,18 +160,20 @@ void master(MPI_Comm comm, MPI_Datatype request_dtype, MPI_Datatype response_dty
112
160
workers.push (i);
113
161
}
114
162
115
- MasterStopToken stop_token (comm);
163
+ // create task queue
164
+
165
+ MasterTaskManager<RequestType> task_manager (comm, tasks_begin, tasks_end);
116
166
117
167
int outstanding = 0 ;
118
- while ((tasks_begin != tasks_end and !stop_token .stop_requested ()) or outstanding > 0 ) {
168
+ while ((!task_manager. empty () and !task_manager .stop_requested ()) or outstanding > 0 ) {
119
169
120
- while (tasks_begin != tasks_end and !stop_token .stop_requested () and !workers.empty ()) {
170
+ while (!task_manager. empty () and !task_manager .stop_requested () and !workers.empty ()) {
121
171
int worker_id = workers.front ();
122
172
++outstanding;
123
173
workers.pop ();
124
174
125
- RequestType request = *tasks_begin ;
126
- ++tasks_begin ;
175
+ RequestType request = std::move (task_manager. front ()) ;
176
+ task_manager. pop () ;
127
177
128
178
MPI_Request mpi_request;
129
179
MPI_Isend (&request, 1 , request_dtype, worker_id, (int )worker_status::more, comm, &mpi_request);
@@ -138,19 +188,26 @@ void master(MPI_Comm comm, MPI_Datatype request_dtype, MPI_Datatype response_dty
138
188
MPI_Wait (&mpi_response, &response_status);
139
189
switch (worker_status (response_status.MPI_TAG )) {
140
190
case worker_status::more:
141
- maybe_stop_token (master_fn, response, stop_token );
191
+ maybe_stop_token (master_fn, response, task_manager );
142
192
break ;
143
193
case worker_status::done:
144
194
workers.push (response_status.MPI_SOURCE );
145
195
outstanding--;
146
196
break ;
147
197
case worker_status::cancel:
148
- stop_token.request_stop ();
198
+ task_manager.request_stop ();
199
+ break ;
200
+ case worker_status::new_task:
201
+ RequestType request;
202
+ MPI_Request mpi_request;
203
+ MPI_Irecv (&request, 1 , request_dtype, response_status.MPI_SOURCE , (int )worker_status::new_task, comm, &mpi_request);
204
+ MPI_Wait (&mpi_request, MPI_STATUS_IGNORE);
205
+ task_manager.push (request);
149
206
break ;
150
207
}
151
208
}
152
209
153
- if (not stop_token .stop_requested ()) stop_token .request_stop ();
210
+ if (not task_manager .stop_requested ()) task_manager .request_stop ();
154
211
155
212
while (not workers.empty ()) {
156
213
int worker_id = workers.front ();
@@ -198,38 +255,38 @@ worker_send(MPI_Comm comm, MPI_Datatype response_dtype, ValueType value)
198
255
MPI_Wait (&request, MPI_STATUS_IGNORE);
199
256
}
200
257
201
- template <typename Function, class Message ,
258
+ template <typename Function, class Message , class RequestType ,
202
259
typename = void >
203
260
struct takes_stop_token : std::false_type
204
261
{};
205
262
206
- template <typename Function, class Message >
263
+ template <typename Function, class Message , class RequestType >
207
264
struct takes_stop_token <
208
- Function, Message,
265
+ Function, Message, RequestType,
209
266
std::void_t <decltype(std::declval<Function>()(
210
- std::declval<Message>(), std::declval<StopToken &>()))>> : std::true_type
267
+ std::declval<Message>(), std::declval<TaskManager<RequestType> &>()))>> : std::true_type
211
268
{};
212
269
213
- template <class Function , class Message , class Enable = void >
270
+ template <class Function , class Message , class RequestType , class Enable = void >
214
271
struct maybe_stop_token_impl {
215
- static auto call (Function f, Message m, StopToken &) {
272
+ static auto call (Function f, Message m, TaskManager<RequestType> &) {
216
273
return f (m);
217
274
}
218
275
};
219
276
220
277
221
- template <class Function , class Message >
222
- struct maybe_stop_token_impl <Function, Message,
223
- typename std::enable_if_t <takes_stop_token<Function,Message>::value>> {
224
- static auto call (Function f, Message m, StopToken & s) {
278
+ template <class Function , class Message , class RequestType >
279
+ struct maybe_stop_token_impl <Function, Message, RequestType,
280
+ typename std::enable_if_t <takes_stop_token<Function,Message, RequestType >::value>> {
281
+ static auto call (Function f, Message m, TaskManager<RequestType> & s) {
225
282
return f (m,s);
226
283
}
227
284
};
228
285
229
- template <class Function , class Message >
230
- auto maybe_stop_token (Function f, Message m, StopToken & s)
286
+ template <class Function , class Message , class RequestType >
287
+ auto maybe_stop_token (Function f, Message m, TaskManager<RequestType> & s)
231
288
{
232
- return maybe_stop_token_impl<Function, Message>::call (f, m, s);
289
+ return maybe_stop_token_impl<Function, Message, RequestType >::call (f, m, s);
233
290
}
234
291
235
292
template <class RequestType , class ResponseType , class Function >
@@ -241,7 +298,7 @@ void worker(MPI_Comm comm, MPI_Datatype request_dtype, MPI_Datatype response_dty
241
298
int done = 0 ;
242
299
MPI_Request stop_request;
243
300
MPI_Ibcast (&done, 1 , MPI_INT, ROOT, comm, &stop_request);
244
- WorkerStopToken stop_token (comm, stop_request);
301
+ WorkerTaskManager<RequestType, ResponseType> stop_token (comm, stop_request, request_dtype, response_dtype );
245
302
246
303
bool worker_done = false ;
247
304
while (!worker_done) {
0 commit comments