Skip to content

Commit fd2ed8c

Browse files
committed
save model output_label
1 parent 4e35c3a commit fd2ed8c

File tree

3 files changed

+27
-14
lines changed

3 files changed

+27
-14
lines changed

Diff for: examples/onnx_output_model/model_opted.onnx

897 Bytes
Binary file not shown.

Diff for: examples/optimizer_c_example/predicate_example.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,19 @@ int main(int argc, char* argv[]) {
1515
std::string type2="string";
1616
std::string name1="sgd_1";
1717
std::string name2="lr_1";
18-
std::string name3="model_lr_1_model_linear_1";
19-
std::string name4="model_lr_2";
18+
std::string name3="sgd_1_lr_1";
19+
std::string name4="nb_1";
2020
std::string name5="model_lr_1_model_linear_1_model_lr_2";
21-
// merge_single_model_with_predicate(path,predicate2,type,name2);
22-
// merge_single_model_with_predicate(path,predicate1,type,name1);
21+
2322
// merge_double_models_with_predicate(path,predicate3,name1,name2);
24-
// merge_single_model_with_predicate(path,predicate4,type,name3);
25-
// merge_double_models_with_predicate(path,predicate2,name1,name2);
26-
// merge_single_model_with_predicate(path,predicate4,type,name3);
23+
2724

2825
merge_single_model_with_predicate(path,predicate1,type,name1,1);
29-
//merge_single_model_with_predicate(path,predicate1,type,name1,2);
3026
merge_single_model_with_predicate(path,predicate1,type,name2,1);
3127
merge_double_models_with_predicate(path,predicate3,name1,name2);
28+
merge_single_model_with_predicate(path,predicate1,type,name4,1);
29+
merge_double_models_with_predicate(path,predicate2,name3,name4);
30+
3231
// merge_single_model_with_predicate(path,predicate1,type,name4,1);
3332
// merge_double_models_with_predicate(path,predicate2,name3,name4);
3433

Diff for: onnxoptimizer/query_c_api/predicate_push_down.cpp

+20-6
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,24 @@ void merge_single_model_with_predicate(std::string& onnx_model_path, std::string
9898
*onnx_model.mutable_graph()->add_node()=reshape_node;
9999
*onnx_model.mutable_graph()->add_input()=input;
100100
*onnx_model.mutable_graph()->add_node()=node;
101+
std::regex pattern_output_label("_output_label$");
101102
std::vector<onnx::ValueInfoProto> output_list;
103+
std::vector<onnx::ValueInfoProto> save_output_list;
102104
for(int i=0;i<onnx_model.graph().output_size();i++){
103-
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern)){
105+
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern))
104106
output_list.push_back(onnx_model.graph().output(i));
105-
}
107+
else if(std::regex_search(onnx_model.graph().output(i).name(), pattern_output_label))
108+
save_output_list.push_back(onnx_model.graph().output(i));
106109
}
107110
onnx_model.mutable_graph()->clear_output();
111+
*onnx_model.mutable_graph()->add_output()=output;
108112
for(int i=0;i<output_list.size();i++){
109113
*onnx_model.mutable_graph()->add_output()=output_list.at(i);
110114
}
111-
*onnx_model.mutable_graph()->add_output()=output;
115+
for(int i=0;i<save_output_list.size();i++){
116+
*onnx_model.mutable_graph()->add_output()=save_output_list.at(i);
117+
}
118+
112119

113120
onnx::checker::check_model(onnx_model);
114121
saveModel(&onnx_model,onnx_model_path);
@@ -139,7 +146,7 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
139146
onnx::ValueInfoProto output_l;
140147
onnx::ValueInfoProto output_r;
141148

142-
for(int i=0;i<onnx_model.graph().output_size();i++){
149+
for(int i=onnx_model.graph().output_size()-1;i>=0;i--){
143150
if(std::regex_search(onnx_model.graph().output(i).name(), pattern_l)&&
144151
!std::regex_search(onnx_model.graph().output(i).name(), pattern_end)&&
145152
!std::regex_search(onnx_model.graph().output(i).name(), pattern_ends)){
@@ -166,18 +173,25 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
166173

167174
// change onnx_model
168175
*onnx_model.mutable_graph()->add_node()=node;
176+
std::regex pattern_output_label("_output_label$");
169177
std::vector<onnx::ValueInfoProto> output_list;
178+
std::vector<onnx::ValueInfoProto> save_output_list;
170179
for(int i=0;i<onnx_model.graph().output_size();i++){
171180
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern_l) &&
172181
!std::regex_search(onnx_model.graph().output(i).name(), pattern_r)){
173182
output_list.push_back(onnx_model.graph().output(i));
174-
}
183+
} else if(std::regex_search(onnx_model.graph().output(i).name(), pattern_output_label))
184+
save_output_list.push_back(onnx_model.graph().output(i));
175185
}
176186
onnx_model.mutable_graph()->clear_output();
187+
*onnx_model.mutable_graph()->add_output()=output;
177188
for(int i=0;i<output_list.size();i++){
178189
*onnx_model.mutable_graph()->add_output()=output_list.at(i);
179190
}
180-
*onnx_model.mutable_graph()->add_output()=output;
191+
for(int i=0;i<save_output_list.size();i++){
192+
*onnx_model.mutable_graph()->add_output()=save_output_list.at(i);
193+
}
194+
181195

182196
onnx::checker::check_model(onnx_model);
183197
saveModel(&onnx_model,onnx_model_path);

0 commit comments

Comments
 (0)