Skip to content

Commit 9a6a683

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[CodeGen] Check odla status after odla API calls
1 parent 78d22f6 commit 9a6a683

File tree

3 files changed

+44
-29
lines changed

3 files changed

+44
-29
lines changed

lib/target/generic_cpp/generic_cxx_codegen.cc

+42-27
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
208208
bool with_type,
209209
bool public_function) {
210210
const static std::string inference_func_decl =
211-
"void model_run(int num_inputs, const void* inputs[],"
211+
"int model_run(int num_inputs, const void* inputs[],"
212212
"int num_outputs, void* outputs[], int batch_size)";
213213
if (opts_.emit_inference_func_sig && func.IsEntryFunction() &&
214214
public_function) {
@@ -221,7 +221,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
221221
ss << "static ";
222222
}
223223
if (with_func_name) {
224-
ss << "void " << NormalizeVariableName(func.GetName());
224+
ss << "int " << NormalizeVariableName(func.GetName());
225225
}
226226
ss << "(";
227227
if (is_sub) {
@@ -612,10 +612,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
612612
oss << "extern \"C\" {\n";
613613
}
614614
oss << " " << func_decl << ";\n";
615-
oss << "void " << init_func_name << "();\n";
616-
oss << "void " << fini_func_name << "();\n";
617-
oss << (is_compile_mode ? "odla_computation " : "static void ")
618-
<< helper_func_name << "();\n";
615+
oss << "int " << init_func_name << "();\n";
616+
oss << "int " << fini_func_name << "();\n";
617+
if (is_compile_mode) {
618+
oss << "int " << helper_func_name << "(odla_computation comp);\n";
619+
} else {
620+
oss << "static void " << helper_func_name << "()\n;";
621+
}
619622
if (opts_.dialect == Dialect::CXX_11) {
620623
oss << "};\n";
621624
}
@@ -626,9 +629,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
626629
if (emit_builder_func) {
627630
if (is_compile_mode) {
628631
os_ << "static odla_computation Comp;\n";
629-
os_ << "odla_computation " << helper_func_name << "() {\n";
630-
os_ << " odla_computation comp;\n";
631-
os_ << " odla_CreateComputation(&comp);\n";
632+
os_ << "int " << helper_func_name << "(odla_computation comp) {\n";
632633
EmitComputationItems(&os_, opts_);
633634
} else {
634635
os_ << "static void " << helper_func_name
@@ -641,11 +642,6 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
641642
os_ << " odla_SetCurrentDevice(device);";
642643
}
643644

644-
if (is_compile_mode) {
645-
os_ << " static odla_computation comp;\n";
646-
os_ << " if (comp == " << EmitNull() << ") {\n";
647-
os_ << " odla_CreateComputation(&comp);\n";
648-
}
649645
EmitComputationItems(&os_, opts_);
650646
}
651647

@@ -675,7 +671,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
675671
RunOnBasicBlock(*bb);
676672
}
677673
if (is_compile_mode) {
678-
os_ << " return comp;\n";
674+
os_ << " return ODLA_SUCCESS;\n";
679675
}
680676

681677
os_ << "}\n"; // End of computation build function.
@@ -684,28 +680,41 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
684680
dynamic_check_os_ << GenerateTestFunc(function, func_decl, *return_inst);
685681
}
686682

683+
const std::string& status_check{
684+
"if (status != ODLA_SUCCESS) { return status;}"};
685+
687686
if (emit_builder_func) {
688687
// Emit function for launching computation.
689688
if (opts_.exec_mode == ExecMode::Compile) {
690689
if (function.IsEntryFunction()) {
691-
os_ << "void " << fini_func_name << "(){\n";
692-
os_ << " odla_DestroyComputation(Comp);\n";
690+
os_ << "int " << fini_func_name << "(){\n";
691+
os_ << " if (Comp !=" << EmitNull() << ") {";
692+
os_ << " return odla_DestroyComputation(Comp);}\n";
693+
os_ << " return ODLA_SUCCESS;\n";
693694
os_ << "}\n";
694695

695-
os_ << "void " << init_func_name << "(){\n";
696+
os_ << "int " << init_func_name << "(){\n";
696697
} else {
697698
os_ << GetFunctionDecl(function, *return_inst, true, true, true)
698699
<< " {\n";
699700
}
700-
os_ << " if (Comp == " << EmitNull() << ") { Comp = " << helper_func_name
701-
<< "(); }\n";
701+
os_ << " odla_status status = ODLA_SUCCESS;\n";
702+
os_ << " if (Comp == " << EmitNull() << ") { \n";
703+
os_ << " status = odla_CreateComputation(&Comp);\n";
704+
os_ << " " << status_check << "\n";
705+
os_ << " status = (odla_status)" << helper_func_name << "(Comp);\n";
706+
os_ << " }\n";
707+
os_ << " return status;\n";
702708
os_ << "}\n";
703709
}
710+
704711
if (function.IsEntryFunction()) {
705712
os_ << GetFunctionDecl(function, *return_inst, true, true, true)
706713
<< " {\n";
707714
if (opts_.exec_mode == ExecMode::Compile) {
708-
os_ << " " << init_func_name << "();\n";
715+
os_ << " odla_status status = ODLA_SUCCESS;\n";
716+
os_ << " status = (odla_status)" << init_func_name << "();\n";
717+
os_ << " " << status_check << "\n";
709718
}
710719
}
711720
if (opts_.exec_mode == ExecMode::Interpret) {
@@ -740,11 +749,14 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
740749

741750
if (opts_.exec_mode == ExecMode::Compile) {
742751
os_ << " static odla_context Ctx;\n";
743-
os_ << " if (Ctx == " << EmitNull()
744-
<< ") { odla_CreateContext(&Ctx); };\n";
752+
os_ << " if (Ctx == " << EmitNull() << ") {";
753+
os_ << " status = odla_CreateContext(&Ctx);\n";
754+
os_ << " " << status_check << "\n";
755+
os_ << " }\n";
745756
if (opts_.emit_dynamic_batch) {
746-
os_ << "odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
757+
os_ << " status = odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
747758
"(odla_item_value) &batch_size);\n";
759+
os_ << " " << status_check << "\n";
748760
}
749761
}
750762
int index = 0;
@@ -755,11 +767,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
755767
? (is_sub ? "inputs.values[" : "inputs[") +
756768
std::to_string(index++) + "]"
757769
: cv.name;
758-
os_ << (is_sub ? " odla_BindValueToArgumentById("
770+
os_ << " status = "
771+
<< (is_sub ? " odla_BindValueToArgumentById("
759772
: " odla_BindToArgumentById(")
760773
<< Join("(const odla_value_id)\"" + arg->GetName() + "\"", arg_name,
761774
"Ctx")
762775
<< ");\n";
776+
os_ << " " << status_check << "\n";
763777
}
764778
index = 0;
765779
// Pre-launch binding.
@@ -769,12 +783,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
769783
? (is_sub ? "outputs.values[" : "outputs[") +
770784
std::to_string(index++) + "]"
771785
: "out_" + cv.name;
772-
os_ << " odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById("
786+
os_ << " status = odla_Bind" << (is_sub ? "Value" : "") << "ToOutputById("
773787
<< Join("(const odla_value_id)\"" + cv.name + "\"", arg_name, "Ctx")
774788
<< ");\n";
789+
os_ << " " << status_check << "\n";
775790
}
776791
if (opts_.exec_mode == ExecMode::Compile) {
777-
os_ << " odla_ExecuteComputation(Comp, Ctx, "
792+
os_ << " return odla_ExecuteComputation(Comp, Ctx, "
778793
"ODLA_COMPUTE_INFERENCE, "
779794
<< EmitNull() << ");\n";
780795
}

tests/compile/test_cxx_gen.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
// GEN: static odla_computation Comp;
3939

40-
// GEN: void func(const float input[3], float out_add1[3]) {
40+
// GEN: int func(const float input[3], float out_add1[3]) {
4141
// GEN: func_init();
4242
// GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx);
4343
// GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx);

tests/compile/test_cxx_gen_gpu.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
// GEN: static odla_computation Comp;
3636

37-
// GEN: void func(const float input[3], float out_add1[3]) {
37+
// GEN: int func(const float input[3], float out_add1[3]) {
3838
// GEN: func_init();
3939
// GEN: odla_BindToArgumentById((const odla_value_id)"input", input, Ctx);
4040
// GEN: odla_BindToOutputById((const odla_value_id)"add1", out_add1, Ctx);

0 commit comments

Comments
 (0)