Skip to content

Commit 6e06f70

Browse files
yanwei-gr, yanwei, and jackzipu
authored
cherry-pick qtc optimization updates by jackz (#723)
* Update to Poplar 2.3.1 and remove unused code. * Update log messages and add an option to enable logging. * Change the wait to use a condition variable, change ContextQueues to use a mutex, reformat the code to meet style requirements, and correct `==` to `=` when setting the buffer_ element to nullptr. Co-authored-by: yanwei <yw01041751@alibaba-inc.com> Co-authored-by: gcuser <jackz@graphcore.ai>
1 parent 97f5a65 commit 6e06f70

File tree

2 files changed

+104
-99
lines changed

2 files changed

+104
-99
lines changed

ODLA/platforms/odla_popart/odla_pipeline.cc

+85-72
Original file line numberDiff line numberDiff line change
@@ -100,87 +100,101 @@ void QManager::deleteQ() {
100100
}
101101
}
102102

103+
void ContextQueues::init(std::size_t capacity) {
104+
buffer_ = new odla_context[capacity];
105+
if (nullptr == buffer_)
106+
throw std::invalid_argument(
107+
"ContextQueues::init failed to create buffer for queue with capacity "
108+
": " +
109+
std::to_string(capacity));
110+
for (int i = 0; i < capacity; i++) buffer_[i] = nullptr;
111+
capacity_ = capacity;
112+
}
113+
103114
void ContextQueues::put(odla_context ctx) {
104115
popart::logging::info("ContextQueues::put -> ctx: {}.", ctx);
105116
{
106-
std::lock_guard<std::mutex> guard(write_mutex);
107-
write_queue->push(ctx);
108-
write_wait_queue->push(
109-
ctx); // put the ctx to input & wait_output queue in same order.
110-
}
117+
std::lock_guard<std::mutex> guard(queue_mutex_);
118+
auto new_tail = (tail_ + 1) % capacity_;
119+
if (new_tail == wait_) // last item as the boundary
120+
throw std::out_of_range("ContextQueues::put the queue is full");
121+
buffer_[tail_] = ctx;
122+
tail_ = new_tail;
123+
} // Make sure the queue mutex is released here.
124+
// Notify the batch wait we got a batch data
125+
std::unique_lock<std::mutex> lock(batch_wait_mutex_);
126+
batch_wait_cv_.notify_one();
111127
}
112128

113129
odla_context ContextQueues::get_input_context() {
114-
if (nullptr != input_ctx) {
115-
return input_ctx;
116-
}
117-
if (!read_queue->empty())
118-
input_ctx = read_queue->front();
119-
else // read queue is empty, switch it
120-
{
121-
std::lock_guard<std::mutex> guard(write_mutex);
122-
std::queue<odla_context>* tmp = read_queue;
123-
read_queue = write_queue;
124-
write_queue = tmp;
125-
popart::logging::info(
126-
"switched the read write queue, now read queu size is: {}.",
127-
read_queue->size());
128-
if (!read_queue->empty())
129-
input_ctx = read_queue->front();
130-
else { // create a zero data if there's not data in the 2 queues
131-
input_ctx = create_empty_odla_context();
132-
write_wait_queue->push(
133-
input_ctx); // Make it wait for the return for the empty data
134-
}
135-
}
136-
137-
return input_ctx;
130+
throw std::runtime_error(
131+
"ContextQueues::get_input_context we should never call this.");
138132
}
139133

140134
odla_context ContextQueues::get_output_context() {
141-
if (output_ctx != nullptr) return output_ctx;
142-
if (!read_wait_queue->empty())
143-
output_ctx = read_wait_queue->front();
144-
else {
145-
// switch the wait queue
146-
std::lock_guard<std::mutex> guard(
147-
write_mutex); // Use the same mutex to save 1 mutex lock for every put
148-
std::queue<odla_context>* tmp = read_wait_queue;
149-
read_wait_queue = write_wait_queue;
150-
write_wait_queue = tmp;
151-
popart::logging::info(
152-
"switched the read write wait queue, now read queu size is: {}.",
153-
read_wait_queue->size());
154-
}
155-
if (!read_wait_queue->empty()) output_ctx = read_wait_queue->front();
156-
if (nullptr == output_ctx)
135+
if (wait_ == tail_)
157136
throw std::out_of_range(
158-
"*** FATAL ERROR *** No context in the queue when an output gotten");
159-
return output_ctx;
137+
"ContextQueues: queue is empty when get_output_context()");
138+
return buffer_[wait_];
160139
}
161140

162141
void ContextQueues::pop_input(odla_context ctx) {
163-
popart::logging::info("ContextQueues::pop_input with ctx: {}", input_ctx);
164-
if (!input_ctx->deletable()) // Only pop the non zero ctx, the zero one not in
165-
// the queue
166-
read_queue->pop();
167-
input_ctx = nullptr;
142+
popart::logging::info("ContextQueues::pop_input with ctx: {}", ctx);
143+
assert(ctx == buffer_[head_]);
144+
head_ = (head_ + 1) % capacity_;
168145
}
169146

170-
void ContextQueues::pop_output(
171-
odla_context
172-
ctx) { // Never delete a context here, only operate on the queue
173-
// wait_output_queue.pop();
174-
if (!read_wait_queue
175-
->empty()) // There must be an element when all tensor written
176-
read_wait_queue->pop(); // pop the first one from the read wait queue
177-
else {
178-
throw std::out_of_range(
179-
"*** FATAL ERROR *** no ctx in read_wait_queue when pop_output called");
180-
}
181-
output_ctx = nullptr;
147+
void ContextQueues::pop_output(odla_context ctx) {
148+
if (wait_ == head_)
149+
throw std::runtime_error("Got out before input all read on index " +
150+
std::to_string(wait_));
151+
assert(ctx == buffer_[wait_]);
152+
buffer_[wait_] = nullptr; // clear the buffer to nullptr;
153+
wait_ = (wait_ + 1) % capacity_;
182154
}
183155

156+
odla_context ContextQueues::get_ctx_by_tensor(const popart::TensorId& id) {
157+
std::uint32_t idx = -1;
158+
odla_context ctx = nullptr;
159+
// Get current index
160+
auto iter = tensor_to_idx_.find(id);
161+
if (tensor_to_idx_.end() == iter)
162+
idx = 0;
163+
else
164+
idx = iter->second;
165+
// Check whether it is empty; tail always points to the first element not written
166+
std::uint32_t cnt = 0;
167+
popart::logging::info("ContextQueues::get_ctx_by_tensor queue has size: {}",
168+
size());
169+
while (idx == tail_) {
170+
auto locked_tail = tail_;
171+
{
172+
std::lock_guard<std::mutex> guard(queue_mutex_);
173+
locked_tail = tail_;
174+
}
175+
if (idx == locked_tail) {
176+
std::unique_lock<std::mutex> lock(batch_wait_mutex_);
177+
batch_wait_cv_.wait_for(lock, std::chrono::milliseconds(5));
178+
}
179+
if (idx != tail_) break;
180+
popart::logging::info(
181+
"[get_ctx_by_tensor] the queue is empty when read, add zero contexts");
182+
if (cnt++ > 1)
183+
throw std::runtime_error(
184+
"[get_ctx_by_tensor] Must get one ctx in 2 fetch, as empty one "
185+
"created.");
186+
odla_context zero_ctx = create_empty_odla_context();
187+
put(zero_ctx);
188+
}
189+
// The lock ensured the ctx has been written
190+
ctx = buffer_[idx];
191+
popart::logging::info(
192+
"ContextQueues::get_ctx_by_tensor tensorid:{} got ctx:{} with idx: {}",
193+
id, ctx, idx);
194+
// Update the index of the tensor to next
195+
tensor_to_idx_[id] = (idx + 1) % capacity_;
196+
return ctx;
197+
}
184198
/*------------------------------------------------------------------------*/
185199
LockFreeQueue::LockFreeQueue() : head_(0), tail_(0), wait_(0) {}
186200

@@ -222,6 +236,9 @@ void LockFreeQueue::put(odla_context ctx) {
222236
if (cnt++ > 5)
223237
throw std::runtime_error("LockFreeQueue::put No one should stop me");
224238
}
239+
// Notify the batch wait we got a batch data
240+
std::unique_lock<std::mutex> lock(batch_wait_mutex_);
241+
batch_wait_cv_.notify_one();
225242
popart::logging::info(
226243
"[LockFreeQueue::put] Set the idx: {} for ctx: {} in {} times.", idx, ctx,
227244
cnt);
@@ -290,15 +307,11 @@ odla_context LockFreeQueue::get_ctx_by_tensor(const popart::TensorId& id) {
290307
popart::logging::info("LockFreeQueue::get_ctx_by_tensor queue has size: {}",
291308
size());
292309
while (idx == tail_.load()) {
293-
bool got_data = false;
294-
for (int i = 0; i < 5; i++) {
295-
std::this_thread::sleep_for(std::chrono::milliseconds(1));
296-
if (idx != tail_.load()) {
297-
got_data = true;
298-
break;
299-
}
310+
if (idx == tail_.load()) {
311+
std::unique_lock<std::mutex> lock(batch_wait_mutex_);
312+
batch_wait_cv_.wait_for(lock, std::chrono::milliseconds(5));
300313
}
301-
if (got_data) break;
314+
if (idx != tail_.load()) break;
302315
popart::logging::info(
303316
"[get_ctx_by_tensor] the queue is empty when read, add zero contexts");
304317
if (cnt++ > 1)

ODLA/platforms/odla_popart/odla_pipeline.h

+19-27
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <atomic>
2525
#include <chrono>
26+
#include <condition_variable>
2627
#include <mutex>
2728
#include <popart/stepio.hpp>
2829
#include <queue>
@@ -48,40 +49,29 @@ class Queue {
4849

4950
class ContextQueues : public Queue {
5051
private:
51-
std::queue<odla_context> input_queue_1;
52-
std::queue<odla_context> input_queue_2;
53-
std::queue<odla_context> wait_output_queue_1;
54-
std::queue<odla_context> wait_output_queue_2;
55-
std::mutex write_mutex;
56-
std::queue<odla_context>* read_queue;
57-
std::queue<odla_context>* write_queue;
58-
std::queue<odla_context>* read_wait_queue;
59-
std::queue<odla_context>* write_wait_queue;
60-
odla_context input_ctx; // the context which is under reading
61-
odla_context output_ctx; // the context which is under writing
52+
odla_context* buffer_;
53+
std::size_t capacity_;
54+
std::uint32_t head_;
55+
std::uint32_t tail_;
56+
std::uint32_t wait_;
57+
std::map<popart::TensorId, std::uint32_t> tensor_to_idx_;
58+
std::mutex batch_wait_mutex_;
59+
std::condition_variable batch_wait_cv_;
60+
std::mutex queue_mutex_; // lock the read & write
6261

6362
public:
64-
ContextQueues()
65-
: read_queue(&input_queue_1),
66-
write_queue(&input_queue_2),
67-
read_wait_queue(&wait_output_queue_1),
68-
write_wait_queue(&wait_output_queue_2),
69-
input_ctx(nullptr),
70-
output_ctx(nullptr) {}
71-
72-
~ContextQueues() {}
73-
void init(std::size_t capacity) final {}
63+
ContextQueues() : head_(0), tail_(0), wait_(0){};
64+
~ContextQueues() {
65+
if (buffer_) delete[] buffer_;
66+
}
67+
void init(std::size_t capacity);
7468
void put(odla_context ctx) final;
7569
odla_context get_input_context() final;
76-
odla_context get_ctx_by_tensor(const popart::TensorId& id) final {
77-
return nullptr;
78-
}
70+
odla_context get_ctx_by_tensor(const popart::TensorId& id) final;
7971
odla_context get_output_context() final;
8072
void pop_input(odla_context ctx) final;
8173
void pop_output(odla_context ctx) final;
82-
std::size_t size() final {
83-
return input_queue_1.size() + input_queue_2.size();
84-
}
74+
std::size_t size() final { return (tail_ - wait_ + capacity_) % capacity_; }
8575
};
8676

8777
class LockFreeQueue : public Queue {
@@ -92,6 +82,8 @@ class LockFreeQueue : public Queue {
9282
std::atomic<uint32_t> tail_;
9383
std::uint32_t wait_;
9484
std::map<popart::TensorId, std::uint32_t> tensor_to_idx_;
85+
std::mutex batch_wait_mutex_;
86+
std::condition_variable batch_wait_cv_;
9587

9688
public:
9789
LockFreeQueue();

0 commit comments

Comments
 (0)