Skip to content

Commit cb3a172

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[ODLA][Async] Add status as callback argument
1 parent ed676a2 commit cb3a172

File tree

4 files changed

+6
-4
lines changed

4 files changed

+6
-4
lines changed

ODLA/platforms/odla_popart/odla_compute.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ odla_status odla_SetContextItem(odla_context context, odla_item_type type,
8585
}
8686
switch (type) {
8787
case ODLA_ASYNC_CALLBACK_FUNC:
88-
context->async_callback_func = reinterpret_cast<void (*)(void*)>(value);
88+
context->async_callback_func =
89+
reinterpret_cast<void (*)(void*, odla_status)>(value);
8990
break;
9091
case ODLA_ASYNC_CALLBACK_ARG:
9192
context->async_callback_arg = reinterpret_cast<void*>(value);

ODLA/platforms/odla_popart/odla_pipeline.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ struct _odla_pipeline_async_context : public _odla_pipeline_context {
216216
this);
217217
throw std::invalid_argument("async_callback_arg is null");
218218
}
219-
async_callback_func(async_callback_arg);
219+
async_callback_func(async_callback_arg,
220+
ODLA_SUCCESS); // FIXME: notify the status
220221
}
221222
bool hold(const std::string& function_name) override { return true; }
222223
};

ODLA/platforms/odla_popart/odla_popart.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ struct _odla_context {
176176
odla_computation comp;
177177
std::map<popart::TensorId, std::unique_ptr<popart::IArray>> inputs;
178178
std::map<popart::TensorId, std::unique_ptr<popart::IArray>> outputs;
179-
void (*async_callback_func)(void*) = nullptr;
179+
void (*async_callback_func)(void*, odla_status) = nullptr;
180180
void* async_callback_arg = nullptr;
181181
_odla_context(odla_computation c) : comp(c) {}
182182
std::thread::id thread_id_of_holder;

lib/target/generic_cpp/generic_cxx_codegen.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
626626
oss << "extern \"C\" {\n";
627627
}
628628
if (opts_.emit_code_for_async) {
629-
oss << " typedef int (*model_run_callback)(void *);\n";
629+
oss << " typedef int (*model_run_callback)(void *, odla_status);\n";
630630
}
631631
oss << " " << func_decl << ";\n";
632632
oss << "int " << init_func_name << "();\n";

0 commit comments

Comments
 (0)