@@ -98,17 +98,24 @@ void merge_single_model_with_predicate(std::string& onnx_model_path, std::string
98
98
*onnx_model.mutable_graph ()->add_node ()=reshape_node;
99
99
*onnx_model.mutable_graph ()->add_input ()=input;
100
100
*onnx_model.mutable_graph ()->add_node ()=node;
101
+ std::regex pattern_output_label (" _output_label$" );
101
102
std::vector<onnx::ValueInfoProto> output_list;
103
+ std::vector<onnx::ValueInfoProto> save_output_list;
102
104
for (int i=0 ;i<onnx_model.graph ().output_size ();i++){
103
- if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern)){
105
+ if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern))
104
106
output_list.push_back (onnx_model.graph ().output (i));
105
- }
107
+ else if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_output_label))
108
+ save_output_list.push_back (onnx_model.graph ().output (i));
106
109
}
107
110
onnx_model.mutable_graph ()->clear_output ();
111
+ *onnx_model.mutable_graph ()->add_output ()=output;
108
112
for (int i=0 ;i<output_list.size ();i++){
109
113
*onnx_model.mutable_graph ()->add_output ()=output_list.at (i);
110
114
}
111
- *onnx_model.mutable_graph ()->add_output ()=output;
115
+ for (int i=0 ;i<save_output_list.size ();i++){
116
+ *onnx_model.mutable_graph ()->add_output ()=save_output_list.at (i);
117
+ }
118
+
112
119
113
120
onnx::checker::check_model (onnx_model);
114
121
saveModel (&onnx_model,onnx_model_path);
@@ -139,7 +146,7 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
139
146
onnx::ValueInfoProto output_l;
140
147
onnx::ValueInfoProto output_r;
141
148
142
- for (int i=0 ;i< onnx_model.graph ().output_size ();i++ ){
149
+ for (int i=onnx_model.graph ().output_size ()- 1 ;i>= 0 ;i-- ){
143
150
if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_l)&&
144
151
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_end)&&
145
152
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_ends)){
@@ -166,18 +173,25 @@ void merge_double_models_with_predicate(std::string& onnx_model_path,std::string
166
173
167
174
// change onnx_model
168
175
*onnx_model.mutable_graph ()->add_node ()=node;
176
+ std::regex pattern_output_label (" _output_label$" );
169
177
std::vector<onnx::ValueInfoProto> output_list;
178
+ std::vector<onnx::ValueInfoProto> save_output_list;
170
179
for (int i=0 ;i<onnx_model.graph ().output_size ();i++){
171
180
if (!std::regex_search (onnx_model.graph ().output (i).name (), pattern_l) &&
172
181
!std::regex_search (onnx_model.graph ().output (i).name (), pattern_r)){
173
182
output_list.push_back (onnx_model.graph ().output (i));
174
- }
183
+ } else if (std::regex_search (onnx_model.graph ().output (i).name (), pattern_output_label))
184
+ save_output_list.push_back (onnx_model.graph ().output (i));
175
185
}
176
186
onnx_model.mutable_graph ()->clear_output ();
187
+ *onnx_model.mutable_graph ()->add_output ()=output;
177
188
for (int i=0 ;i<output_list.size ();i++){
178
189
*onnx_model.mutable_graph ()->add_output ()=output_list.at (i);
179
190
}
180
- *onnx_model.mutable_graph ()->add_output ()=output;
191
+ for (int i=0 ;i<save_output_list.size ();i++){
192
+ *onnx_model.mutable_graph ()->add_output ()=save_output_list.at (i);
193
+ }
194
+
181
195
182
196
onnx::checker::check_model (onnx_model);
183
197
saveModel (&onnx_model,onnx_model_path);
0 commit comments