Skip to content

Commit a4e0743

Browse files
jackzipuweimingzha0
authored and committed
callback function definition changes, and handle the error raised by poplar SDK
1 parent cb3a172 commit a4e0743

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

ODLA/platforms/odla_popart/odla_compute.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ odla_status odla_SetContextItem(odla_context context, odla_item_type type,
8686
switch (type) {
8787
case ODLA_ASYNC_CALLBACK_FUNC:
8888
context->async_callback_func =
89-
reinterpret_cast<void (*)(void*, odla_status)>(value);
89+
reinterpret_cast<int (*)(void*, odla_status)>(value);
9090
break;
9191
case ODLA_ASYNC_CALLBACK_ARG:
9292
context->async_callback_arg = reinterpret_cast<void*>(value);

ODLA/platforms/odla_popart/odla_pipeline.cc

+19
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,16 @@ odla_context ContextQueues::get_ctx_by_tensor(const popart::TensorId& id) {
195195
tensor_to_idx_[id] = (idx + 1) % capacity_;
196196
return ctx;
197197
}
198+
199+
// Error path for this queue: visits every slot from wait_ up to (but not
// including) tail_, wrapping at capacity_, and calls notify() on each
// non-null context stored there.
// Reached through Queue::handle_error() from QManager::set_status() when a
// non-success status is recorded.
// NOTE(review): presumably notify() releases whoever is waiting on the
// context -- confirm against its definition.
// NOTE(review): no lock is taken inside this function; verify whether the
// caller is expected to hold the queue's lock while wait_/tail_ are read.
void ContextQueues::handle_error() {
200+
auto idx = wait_;  // snapshot of the starting slot; wait_ itself is not modified
201+
while (idx != tail_) {  // tail_ is re-read on every iteration
202+
odla_context ctx = buffer_[idx];
203+
if (ctx != nullptr) ctx->notify();  // empty slots are skipped
204+
idx = (idx + 1) % capacity_;  // advance with wrap-around
205+
}
206+
}
207+
198208
/*------------------------------------------------------------------------*/
199209
LockFreeQueue::LockFreeQueue() : head_(0), tail_(0), wait_(0) {}
200210

@@ -344,6 +354,15 @@ odla_context LockFreeQueue::get_ctx_by_tensor(const popart::TensorId& id) {
344354
return ctx;
345355
}
346356

357+
// Error path for the lock-free queue: visits every slot from wait_ up to (but
// not including) the current tail_, wrapping at capacity_, and calls notify()
// on each non-null context loaded from buffer_.
// Reached through Queue::handle_error() from QManager::set_status() when a
// non-success status is recorded.
// NOTE(review): presumably notify() releases whoever is waiting on the
// context -- confirm against its definition.
// NOTE(review): wait_ is read without .load() here, unlike tail_ and the
// buffer_ slots -- confirm which thread owns wait_.
void LockFreeQueue::handle_error() {
358+
auto idx = wait_;  // snapshot of the starting slot; wait_ itself is not modified
359+
while (idx != tail_.load()) {  // tail_ is re-loaded on every iteration
360+
odla_context ctx = buffer_[idx].load();
361+
if (ctx != nullptr) ctx->notify();  // empty slots are skipped
362+
idx = (idx + 1) % capacity_;  // advance with wrap-around
363+
}
364+
}
365+
347366
/*======================================== step io callbacks
348367
* =========================================*/
349368
popart::StepIOCallback::InputCallback input_callback =

ODLA/platforms/odla_popart/odla_pipeline.h

+8-3
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ class Queue {
4545
virtual void pop_input(odla_context ctx) = 0;
4646
virtual void pop_output(odla_context ctx) = 0;
4747
virtual std::size_t size() = 0;
48+
virtual void handle_error() = 0;
4849
};
4950

5051
class ContextQueues : public Queue {
@@ -72,6 +73,7 @@ class ContextQueues : public Queue {
7273
void pop_input(odla_context ctx) final;
7374
void pop_output(odla_context ctx) final;
7475
std::size_t size() final { return (tail_ - wait_ + capacity_) % capacity_; }
76+
void handle_error() final;
7577
};
7678

7779
class LockFreeQueue : public Queue {
@@ -100,6 +102,7 @@ class LockFreeQueue : public Queue {
100102
std::size_t size() final {
101103
return (tail_.load() - wait_ + capacity_) % capacity_;
102104
}
105+
void handle_error() final;
103106
};
104107

105108
class QManager {
@@ -115,7 +118,10 @@ class QManager {
115118
void createQ(std::string queueType);
116119
void deleteQ();
117120
inline Queue* getQ() { return queue_; }
118-
inline void set_status(odla_status status) { status_ = status; }
121+
// Records the given status and, when it is anything other than ODLA_SUCCESS
// and a queue exists, calls the queue's handle_error().
// status_ is assigned before handle_error() runs.
// NOTE(review): that ordering looks intentional so that code reached from
// handle_error() can read the failure via get_status() -- confirm.
inline void set_status(odla_status status) {
122+
status_ = status;
123+
if (ODLA_SUCCESS != status_ && queue_) queue_->handle_error();  // queue_ may be null; skipped in that case
124+
}
119125
inline odla_status get_status() { return status_; }
120126
static inline QManager* instance() { return instance_; }
121127
};
@@ -216,8 +222,7 @@ struct _odla_pipeline_async_context : public _odla_pipeline_context {
216222
this);
217223
throw std::invalid_argument("async_callback_arg is null");
218224
}
219-
async_callback_func(async_callback_arg,
220-
ODLA_SUCCESS); // FIXME: notify the status
225+
async_callback_func(async_callback_arg, QManager::instance()->get_status());
221226
}
222227
bool hold(const std::string& function_name) override { return true; }
223228
};

ODLA/platforms/odla_popart/odla_popart.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ struct _odla_context {
176176
odla_computation comp;
177177
std::map<popart::TensorId, std::unique_ptr<popart::IArray>> inputs;
178178
std::map<popart::TensorId, std::unique_ptr<popart::IArray>> outputs;
179-
void (*async_callback_func)(void*, odla_status) = nullptr;
179+
int (*async_callback_func)(void*, odla_status) = nullptr;
180180
void* async_callback_arg = nullptr;
181181
_odla_context(odla_computation c) : comp(c) {}
182182
std::thread::id thread_id_of_holder;

0 commit comments

Comments
 (0)