Skip to content

Commit 86cb9be

Browse files
committed
Merge branch 'opted_with_output'
# Conflicts: # examples/onnx_output_model/model_opted.onnx # examples/optimizer_c_example/predicate_example.cpp
2 parents 0dbd417 + fd2ed8c commit 86cb9be

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

Diff for: onnxoptimizer/query_c_api/predicate_push_down.cpp

+20-6
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,24 @@ void merge_single_model_with_predicate(std::string& onnx_model_path, std::string
100100
*onnx_model.mutable_graph()->add_node()=reshape_node;
101101
*onnx_model.mutable_graph()->add_input()=input;
102102
*onnx_model.mutable_graph()->add_node()=node;
103+
std::regex pattern_output_label("_output_label$");
103104
std::vector<onnx::ValueInfoProto> output_list;
105+
std::vector<onnx::ValueInfoProto> save_output_list;
104106
for(int i=0;i<onnx_model.graph().output_size();i++){
105-
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern)){
107+
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern))
106108
output_list.push_back(onnx_model.graph().output(i));
107-
}
109+
else if(std::regex_search(onnx_model.graph().output(i).name(), pattern_output_label))
110+
save_output_list.push_back(onnx_model.graph().output(i));
108111
}
109112
onnx_model.mutable_graph()->clear_output();
113+
*onnx_model.mutable_graph()->add_output()=output;
110114
for(int i=0;i<output_list.size();i++){
111115
*onnx_model.mutable_graph()->add_output()=output_list.at(i);
112116
}
113-
*onnx_model.mutable_graph()->add_output()=output;
117+
for(int i=0;i<save_output_list.size();i++){
118+
*onnx_model.mutable_graph()->add_output()=save_output_list.at(i);
119+
}
120+
114121

115122
onnx::checker::check_model(onnx_model);
116123
saveModel(&onnx_model,onnx_model_path);
@@ -141,7 +148,7 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
141148
onnx::ValueInfoProto output_l;
142149
onnx::ValueInfoProto output_r;
143150

144-
for(int i=0;i<onnx_model.graph().output_size();i++){
151+
for(int i=onnx_model.graph().output_size()-1;i>=0;i--){
145152
if(std::regex_search(onnx_model.graph().output(i).name(), pattern_l)&&
146153
!std::regex_search(onnx_model.graph().output(i).name(), pattern_end)&&
147154
!std::regex_search(onnx_model.graph().output(i).name(), pattern_ends)){
@@ -168,18 +175,25 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
168175

169176
// change onnx_model
170177
*onnx_model.mutable_graph()->add_node()=node;
178+
std::regex pattern_output_label("_output_label$");
171179
std::vector<onnx::ValueInfoProto> output_list;
180+
std::vector<onnx::ValueInfoProto> save_output_list;
172181
for(int i=0;i<onnx_model.graph().output_size();i++){
173182
if(!std::regex_search(onnx_model.graph().output(i).name(), pattern_l) &&
174183
!std::regex_search(onnx_model.graph().output(i).name(), pattern_r)){
175184
output_list.push_back(onnx_model.graph().output(i));
176-
}
185+
} else if(std::regex_search(onnx_model.graph().output(i).name(), pattern_output_label))
186+
save_output_list.push_back(onnx_model.graph().output(i));
177187
}
178188
onnx_model.mutable_graph()->clear_output();
189+
*onnx_model.mutable_graph()->add_output()=output;
179190
for(int i=0;i<output_list.size();i++){
180191
*onnx_model.mutable_graph()->add_output()=output_list.at(i);
181192
}
182-
*onnx_model.mutable_graph()->add_output()=output;
193+
for(int i=0;i<save_output_list.size();i++){
194+
*onnx_model.mutable_graph()->add_output()=save_output_list.at(i);
195+
}
196+
183197

184198
onnx::checker::check_model(onnx_model);
185199
saveModel(&onnx_model,onnx_model_path);

0 commit comments

Comments
 (0)