@@ -100,17 +100,24 @@ void merge_single_model_with_predicate(std::string& onnx_model_path, std::string
100
100
*onnx_model.mutable_graph ()->add_node ()=reshape_node;
101
101
*onnx_model.mutable_graph ()->add_input ()=input;
102
102
*onnx_model.mutable_graph ()->add_node ()=node;
103
+ std::regex pattern_output_label (" _output_label$" );
103
104
std::vector<onnx::ValueInfoProto> output_list;
105
+ std::vector<onnx::ValueInfoProto> save_output_list;
104
106
for (int i=0 ;i<onnx_model.graph ().output_size ();i++){
105
- if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern)){
107
+ if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern))
106
108
output_list.push_back (onnx_model.graph ().output (i));
107
- }
109
+ else if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_output_label))
110
+ save_output_list.push_back (onnx_model.graph ().output (i));
108
111
}
109
112
onnx_model.mutable_graph ()->clear_output ();
113
+ *onnx_model.mutable_graph ()->add_output ()=output;
110
114
for (int i=0 ;i<output_list.size ();i++){
111
115
*onnx_model.mutable_graph ()->add_output ()=output_list.at (i);
112
116
}
113
- *onnx_model.mutable_graph ()->add_output ()=output;
117
+ for (int i=0 ;i<save_output_list.size ();i++){
118
+ *onnx_model.mutable_graph ()->add_output ()=save_output_list.at (i);
119
+ }
120
+
114
121
115
122
onnx::checker::check_model (onnx_model);
116
123
saveModel (&onnx_model,onnx_model_path);
@@ -141,7 +148,7 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
141
148
onnx::ValueInfoProto output_l;
142
149
onnx::ValueInfoProto output_r;
143
150
144
- for (int i=0 ;i< onnx_model.graph ().output_size ();i++ ){
151
+ for (int i=onnx_model.graph ().output_size ()- 1 ;i>= 0 ;i-- ){
145
152
if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_l)&&
146
153
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_end)&&
147
154
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_ends)){
@@ -168,18 +175,25 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
168
175
169
176
// change onnx_model
170
177
*onnx_model.mutable_graph ()->add_node ()=node;
178
+ std::regex pattern_output_label (" _output_label$" );
171
179
std::vector<onnx::ValueInfoProto> output_list;
180
+ std::vector<onnx::ValueInfoProto> save_output_list;
172
181
for (int i=0 ;i<onnx_model.graph ().output_size ();i++){
173
182
if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern_l) &&
174
183
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_r)){
175
184
output_list.push_back (onnx_model.graph ().output (i));
176
- }
185
+ } else if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_output_label))
186
+ save_output_list.push_back (onnx_model.graph ().output (i));
177
187
}
178
188
onnx_model.mutable_graph ()->clear_output ();
189
+ *onnx_model.mutable_graph ()->add_output ()=output;
179
190
for (int i=0 ;i<output_list.size ();i++){
180
191
*onnx_model.mutable_graph ()->add_output ()=output_list.at (i);
181
192
}
182
- *onnx_model.mutable_graph ()->add_output ()=output;
193
+ for (int i=0 ;i<save_output_list.size ();i++){
194
+ *onnx_model.mutable_graph ()->add_output ()=save_output_list.at (i);
195
+ }
196
+
183
197
184
198
onnx::checker::check_model (onnx_model);
185
199
saveModel (&onnx_model,onnx_model_path);
0 commit comments