@@ -100,87 +100,101 @@ void QManager::deleteQ() {
100
100
}
101
101
}
102
102
103
+ void ContextQueues::init (std::size_t capacity) {
104
+ buffer_ = new odla_context[capacity];
105
+ if (nullptr == buffer_)
106
+ throw std::invalid_argument (
107
+ " ContextQueues::init failed to create buffer for queue with capacity "
108
+ " : " +
109
+ std::to_string (capacity));
110
+ for (int i = 0 ; i < capacity; i++) buffer_[i] = nullptr ;
111
+ capacity_ = capacity;
112
+ }
113
+
103
114
/// Append a context to the tail of the ring buffer and wake one waiter.
/// @param ctx the context to enqueue (stored by value in the buffer).
/// @throws std::out_of_range when the queue is full.
void ContextQueues::put(odla_context ctx) {
  popart::logging::info(" ContextQueues::put -> ctx: {}." , ctx);
  {
    // Reserve the slot under the queue mutex.
    std::lock_guard<std::mutex> guard(queue_mutex_);
    const auto next = (tail_ + 1) % capacity_;
    // One slot is always kept free so "full" is distinguishable from "empty".
    if (next == wait_)
      throw std::out_of_range(" ContextQueues::put the queue is full" );
    buffer_[tail_] = ctx;
    tail_ = next;
  } // Queue mutex must be released before signalling.
  // Wake a thread blocked waiting for batch data.
  std::lock_guard<std::mutex> lock(batch_wait_mutex_);
  batch_wait_cv_.notify_one();
}
112
128
113
129
/// Legacy entry point retired by the ring-buffer redesign: input contexts are
/// now fetched through get_ctx_by_tensor(). Fail loudly if anything still
/// calls this path.
/// @throws std::runtime_error unconditionally.
odla_context ContextQueues::get_input_context () {
  throw std::runtime_error (
      " ContextQueues::get_input_context we should never call this." );
}
139
133
140
134
/// Return (without removing) the oldest context still awaiting its output.
/// @return the context at index wait_.
/// @throws std::out_of_range when no context is pending (wait_ == tail_).
odla_context ContextQueues::get_output_context() {
  // wait_ trails tail_; they are equal exactly when nothing is pending.
  if (wait_ != tail_) return buffer_[wait_];
  throw std::out_of_range(
      " ContextQueues: queue is empty when get_output_context() " );
}
161
140
162
141
/// Consume the input context at the head of the ring buffer.
/// @param ctx must be the context currently at head_ (FIFO order enforced
///        only via assert in debug builds).
void ContextQueues::pop_input(odla_context ctx) {
  popart::logging::info(" ContextQueues::pop_input with ctx: {}" , ctx);
  // Inputs must be consumed strictly in arrival order.
  assert(buffer_[head_] == ctx);
  // Advance the read cursor, wrapping around the ring.
  head_ = (head_ + 1) % capacity_;
}
169
146
170
- void ContextQueues::pop_output (
171
- odla_context
172
- ctx) { // Never delete a context here, only operate on the queue
173
- // wait_output_queue.pop();
174
- if (!read_wait_queue
175
- ->empty ()) // There must be an element when all tensor written
176
- read_wait_queue->pop (); // pop the first one from the read wait queue
177
- else {
178
- throw std::out_of_range (
179
- " *** FATAL ERROR *** no ctx in read_wait_queue when pop_output called" );
180
- }
181
- output_ctx = nullptr ;
147
+ void ContextQueues::pop_output (odla_context ctx) {
148
+ if (wait_ == head_)
149
+ throw std::runtime_error (" Got out before input all read on index " +
150
+ std::to_string (wait_));
151
+ assert (ctx == buffer_[wait_]);
152
+ buffer_[wait_] = nullptr ; // clear the buffer to nullptr;
153
+ wait_ = (wait_ + 1 ) % capacity_;
182
154
}
183
155
156
/// Fetch the next context whose tensor `id` should be read, blocking briefly
/// (with a 5 ms condition-variable timeout per round) while the queue is
/// empty and injecting at most two empty "zero" contexts to unblock readers.
/// @param id tensor identifier; each tensor keeps its own read cursor in
///        tensor_to_idx_.
/// @return the context stored at this tensor's cursor index.
/// @throws std::runtime_error if no context appears within two fetch rounds
///         even after an empty context was enqueued.
odla_context ContextQueues::get_ctx_by_tensor (const popart::TensorId& id) {
  std::uint32_t idx = -1 ;  // sentinel; always overwritten below
  odla_context ctx = nullptr ;
  // Resolve this tensor's cursor; a first-time tensor starts at slot 0.
  auto iter = tensor_to_idx_.find (id);
  if (tensor_to_idx_.end () == iter)
    idx = 0 ;
  else
    idx = iter->second ;
  // Emptiness check: tail_ always points to the first slot not yet written,
  // so idx == tail_ means this cursor has no data to read.
  std::uint32_t cnt = 0 ;
  popart::logging::info (" ContextQueues::get_ctx_by_tensor queue has size: {}" ,
                        size ());
  // NOTE(review): the unlocked reads of tail_ in this loop race with put()
  // unless tail_ is atomic — confirm against the class declaration.
  while (idx == tail_) {
    // Double-checked re-read of tail_ under the queue mutex before sleeping.
    auto locked_tail = tail_;
    {
      std::lock_guard<std::mutex> guard (queue_mutex_);
      locked_tail = tail_;
    }
    if (idx == locked_tail) {
      // Still empty: wait (bounded) for put() to signal new batch data.
      std::unique_lock<std::mutex> lock (batch_wait_mutex_);
      batch_wait_cv_.wait_for (lock, std::chrono::milliseconds (5 ));
    }
    if (idx != tail_) break ;
    popart::logging::info (
        " [get_ctx_by_tensor] the queue is empty when read, add zero contexts" );
    // After injecting an empty context we must observe data within two
    // rounds; anything else indicates the queue is wedged.
    if (cnt++ > 1 )
      throw std::runtime_error (
          " [get_ctx_by_tensor] Must get one ctx in 2 fetch, as empty one "
          " created." );
    // Enqueue a placeholder so downstream consumers are not starved.
    odla_context zero_ctx = create_empty_odla_context ();
    put (zero_ctx);
  }
  // By here tail_ moved past idx, so the slot has been written.
  ctx = buffer_[idx];
  popart::logging::info (
      " ContextQueues::get_ctx_by_tensor tensorid:{} got ctx:{} with idx: {}" ,
      id, ctx, idx);
  // Advance this tensor's cursor to the next ring slot.
  tensor_to_idx_[id] = (idx + 1 ) % capacity_;
  return ctx;
}
184
198
/* ------------------------------------------------------------------------*/
185
199
LockFreeQueue::LockFreeQueue () : head_(0 ), tail_(0 ), wait_(0 ) {}
186
200
@@ -222,6 +236,9 @@ void LockFreeQueue::put(odla_context ctx) {
222
236
if (cnt++ > 5 )
223
237
throw std::runtime_error (" LockFreeQueue::put No one should stop me" );
224
238
}
239
+ // Notify the batch wait we got a batch data
240
+ std::unique_lock<std::mutex> lock (batch_wait_mutex_);
241
+ batch_wait_cv_.notify_one ();
225
242
popart::logging::info (
226
243
" [LockFreeQueue::put] Set the idx: {} for ctx: {} in {} times." , idx, ctx,
227
244
cnt);
@@ -290,15 +307,11 @@ odla_context LockFreeQueue::get_ctx_by_tensor(const popart::TensorId& id) {
290
307
popart::logging::info (" LockFreeQueue::get_ctx_by_tensor queue has size: {}" ,
291
308
size ());
292
309
while (idx == tail_.load ()) {
293
- bool got_data = false ;
294
- for (int i = 0 ; i < 5 ; i++) {
295
- std::this_thread::sleep_for (std::chrono::milliseconds (1 ));
296
- if (idx != tail_.load ()) {
297
- got_data = true ;
298
- break ;
299
- }
310
+ if (idx == tail_.load ()) {
311
+ std::unique_lock<std::mutex> lock (batch_wait_mutex_);
312
+ batch_wait_cv_.wait_for (lock, std::chrono::milliseconds (5 ));
300
313
}
301
- if (got_data ) break ;
314
+ if (idx != tail_. load () ) break ;
302
315
popart::logging::info (
303
316
" [get_ctx_by_tensor] the queue is empty when read, add zero contexts" );
304
317
if (cnt++ > 1 )
0 commit comments