@@ -208,7 +208,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
208
208
bool with_type,
209
209
bool public_function) {
210
210
const static std::string inference_func_decl =
211
- " void model_run(int num_inputs, const void* inputs[],"
211
+ " int model_run(int num_inputs, const void* inputs[],"
212
212
" int num_outputs, void* outputs[], int batch_size)" ;
213
213
if (opts_.emit_inference_func_sig && func.IsEntryFunction () &&
214
214
public_function) {
@@ -221,7 +221,7 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
221
221
ss << " static " ;
222
222
}
223
223
if (with_func_name) {
224
- ss << " void " << NormalizeVariableName (func.GetName ());
224
+ ss << " int " << NormalizeVariableName (func.GetName ());
225
225
}
226
226
ss << " (" ;
227
227
if (is_sub) {
@@ -612,10 +612,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
612
612
oss << " extern \" C\" {\n " ;
613
613
}
614
614
oss << " " << func_decl << " ;\n " ;
615
- oss << " void " << init_func_name << " ();\n " ;
616
- oss << " void " << fini_func_name << " ();\n " ;
617
- oss << (is_compile_mode ? " odla_computation " : " static void " )
618
- << helper_func_name << " ();\n " ;
615
+ oss << " int " << init_func_name << " ();\n " ;
616
+ oss << " int " << fini_func_name << " ();\n " ;
617
+ if (is_compile_mode) {
618
+ oss << " int " << helper_func_name << " (odla_computation comp);\n " ;
619
+ } else {
620
+ oss << " static void " << helper_func_name << " ()\n ;" ;
621
+ }
619
622
if (opts_.dialect == Dialect::CXX_11) {
620
623
oss << " };\n " ;
621
624
}
@@ -626,9 +629,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
626
629
if (emit_builder_func) {
627
630
if (is_compile_mode) {
628
631
os_ << " static odla_computation Comp;\n " ;
629
- os_ << " odla_computation " << helper_func_name << " () {\n " ;
630
- os_ << " odla_computation comp;\n " ;
631
- os_ << " odla_CreateComputation(&comp);\n " ;
632
+ os_ << " int " << helper_func_name << " (odla_computation comp) {\n " ;
632
633
EmitComputationItems (&os_, opts_);
633
634
} else {
634
635
os_ << " static void " << helper_func_name
@@ -641,11 +642,6 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
641
642
os_ << " odla_SetCurrentDevice(device);" ;
642
643
}
643
644
644
- if (is_compile_mode) {
645
- os_ << " static odla_computation comp;\n " ;
646
- os_ << " if (comp == " << EmitNull () << " ) {\n " ;
647
- os_ << " odla_CreateComputation(&comp);\n " ;
648
- }
649
645
EmitComputationItems (&os_, opts_);
650
646
}
651
647
@@ -675,7 +671,7 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
675
671
RunOnBasicBlock (*bb);
676
672
}
677
673
if (is_compile_mode) {
678
- os_ << " return comp ;\n " ;
674
+ os_ << " return ODLA_SUCCESS ;\n " ;
679
675
}
680
676
681
677
os_ << " }\n " ; // End of computation build function.
@@ -684,28 +680,41 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
684
680
dynamic_check_os_ << GenerateTestFunc (function, func_decl, *return_inst);
685
681
}
686
682
683
+ const std::string& status_check{
684
+ " if (status != ODLA_SUCCESS) { return status;}" };
685
+
687
686
if (emit_builder_func) {
688
687
// Emit function for launching computation.
689
688
if (opts_.exec_mode == ExecMode::Compile) {
690
689
if (function.IsEntryFunction ()) {
691
- os_ << " void " << fini_func_name << " (){\n " ;
692
- os_ << " odla_DestroyComputation(Comp);\n " ;
690
+ os_ << " int " << fini_func_name << " (){\n " ;
691
+ os_ << " if (Comp !=" << EmitNull () << " ) {" ;
692
+ os_ << " return odla_DestroyComputation(Comp);}\n " ;
693
+ os_ << " return ODLA_SUCCESS;\n " ;
693
694
os_ << " }\n " ;
694
695
695
- os_ << " void " << init_func_name << " (){\n " ;
696
+ os_ << " int " << init_func_name << " (){\n " ;
696
697
} else {
697
698
os_ << GetFunctionDecl (function, *return_inst, true , true , true )
698
699
<< " {\n " ;
699
700
}
700
- os_ << " if (Comp == " << EmitNull () << " ) { Comp = " << helper_func_name
701
- << " (); }\n " ;
701
+ os_ << " odla_status status = ODLA_SUCCESS;\n " ;
702
+ os_ << " if (Comp == " << EmitNull () << " ) { \n " ;
703
+ os_ << " status = odla_CreateComputation(&Comp);\n " ;
704
+ os_ << " " << status_check << " \n " ;
705
+ os_ << " status = (odla_status)" << helper_func_name << " (Comp);\n " ;
706
+ os_ << " }\n " ;
707
+ os_ << " return status;\n " ;
702
708
os_ << " }\n " ;
703
709
}
710
+
704
711
if (function.IsEntryFunction ()) {
705
712
os_ << GetFunctionDecl (function, *return_inst, true , true , true )
706
713
<< " {\n " ;
707
714
if (opts_.exec_mode == ExecMode::Compile) {
708
- os_ << " " << init_func_name << " ();\n " ;
715
+ os_ << " odla_status status = ODLA_SUCCESS;\n " ;
716
+ os_ << " status = (odla_status)" << init_func_name << " ();\n " ;
717
+ os_ << " " << status_check << " \n " ;
709
718
}
710
719
}
711
720
if (opts_.exec_mode == ExecMode::Interpret) {
@@ -740,11 +749,14 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
740
749
741
750
if (opts_.exec_mode == ExecMode::Compile) {
742
751
os_ << " static odla_context Ctx;\n " ;
743
- os_ << " if (Ctx == " << EmitNull ()
744
- << " ) { odla_CreateContext(&Ctx); };\n " ;
752
+ os_ << " if (Ctx == " << EmitNull () << " ) {" ;
753
+ os_ << " status = odla_CreateContext(&Ctx);\n " ;
754
+ os_ << " " << status_check << " \n " ;
755
+ os_ << " }\n " ;
745
756
if (opts_.emit_dynamic_batch ) {
746
- os_ << " odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
757
+ os_ << " status = odla_SetContextItem(Ctx, ODLA_RUN_BATCH_SIZE, "
747
758
" (odla_item_value) &batch_size);\n " ;
759
+ os_ << " " << status_check << " \n " ;
748
760
}
749
761
}
750
762
int index = 0 ;
@@ -755,11 +767,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
755
767
? (is_sub ? " inputs.values[" : " inputs[" ) +
756
768
std::to_string (index ++) + " ]"
757
769
: cv.name ;
758
- os_ << (is_sub ? " odla_BindValueToArgumentById("
770
+ os_ << " status = "
771
+ << (is_sub ? " odla_BindValueToArgumentById("
759
772
: " odla_BindToArgumentById(" )
760
773
<< Join (" (const odla_value_id)\" " + arg->GetName () + " \" " , arg_name,
761
774
" Ctx" )
762
775
<< " );\n " ;
776
+ os_ << " " << status_check << " \n " ;
763
777
}
764
778
index = 0 ;
765
779
// Pre-launch binding.
@@ -769,12 +783,13 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
769
783
? (is_sub ? " outputs.values[" : " outputs[" ) +
770
784
std::to_string (index ++) + " ]"
771
785
: " out_" + cv.name ;
772
- os_ << " odla_Bind" << (is_sub ? " Value" : " " ) << " ToOutputById("
786
+ os_ << " status = odla_Bind" << (is_sub ? " Value" : " " ) << " ToOutputById("
773
787
<< Join (" (const odla_value_id)\" " + cv.name + " \" " , arg_name, " Ctx" )
774
788
<< " );\n " ;
789
+ os_ << " " << status_check << " \n " ;
775
790
}
776
791
if (opts_.exec_mode == ExecMode::Compile) {
777
- os_ << " odla_ExecuteComputation(Comp, Ctx, "
792
+ os_ << " return odla_ExecuteComputation(Comp, Ctx, "
778
793
" ODLA_COMPUTE_INFERENCE, "
779
794
<< EmitNull () << " );\n " ;
780
795
}
0 commit comments