Skip to content

Commit 69cda96

Browse files
committed
feat: add prefix on compared model
1 parent a1152f7 commit 69cda96

File tree

11 files changed

+31
-15
lines changed

11 files changed

+31
-15
lines changed
19.9 KB
Binary file not shown.
123 KB
Binary file not shown.
-7.93 MB
Binary file not shown.
53 KB
Binary file not shown.
53.7 KB
Binary file not shown.
-461 KB
Binary file not shown.

examples/optimizer_c_example/optimize_example.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ int main(int argc, char* argv[]) {
2020
// std::string path1="../examples/onnx_input_model/model_lr.onnx";
2121
// std::string path2="../examples/onnx_input_model/model_linear.onnx";
2222

23-
std::string path1="../examples/onnx_input_model/news_nb.onnx";
24-
std::string path2="../examples/onnx_input_model/news_lr.onnx";
25-
std::string path3="../examples/onnx_input_model/news_sgd.onnx";
23+
std::string path1="../examples/onnx_input_model/model_nb.onnx";
24+
std::string path2="../examples/onnx_input_model/model_lr.onnx";
25+
std::string path3="../examples/onnx_input_model/model_sgd.onnx";
2626
// std::string pre1="model_lr_1_";
2727
// std::string pre2="model_linear_1_";
2828
std::string pre1="nb_1_";

examples/optimizer_c_example/redundant.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@ int main(int argc, char* argv[]) {
1414
std::string path7="../examples/onnx_input_model/model_nb.onnx";
1515
std::string path8="../examples/onnx_input_model/flights_lr.onnx";
1616
std::string path9="../examples/onnx_input_model/flights_nb.onnx";
17-
std::string path10="../examples/onnx_input_model/fakenews_lr.onnx";
18-
std::string path11="../examples/onnx_input_model/fakenews_nb.onnx";
17+
std::string path10="../examples/onnx_input_model/fakenews_lr_test.onnx";
18+
std::string path11="../examples/onnx_input_model/fakenews_nb";
19+
std::string path12="../examples/onnx_input_model/model_opted.onnx";
1920
std::string outpath="../examples/onnx_input_model/model_out.onnx";
2021
std::string inpath="../examples/onnx_input_model/model_in.onnx";
21-
22-
std::vector<std::string> tmp = check_redundant(path10,path11);
23-
change_models(path10,outpath,inpath,tmp);
22+
std::string path13="../examples/onnx_input_model/model_reg.onnx";
23+
std::string pre="sgd_1_";
24+
std::string path_pre="../examples/onnx_input_model/model_reg_prefix.onnx";
25+
add_prefix_on_model(path13, path_pre, pre);
26+
std::vector<std::string> tmp = check_redundant(path_pre,path12);
27+
change_models(path_pre,outpath,inpath,tmp);
2428

2529
}

onnxoptimizer/optimize_c_api/optimize_c_api.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <fstream>
66

77
#include "optimize_c_api.h"
8-
#include "onnxoptimizer/query_c_api/model_merge.cpp"
8+
//#include "onnxoptimizer/query_c_api/model_merge.cpp"
99
#include "onnxoptimizer/query_c_api/predicate_push_down.cpp"
1010
#include "onnxoptimizer/query_c_api/redundant_calculation_detection.cpp"
1111

@@ -49,3 +49,6 @@ void change_models(std::string& changed_model_path,std::string& output_model_pat
4949
const std::vector<std::string>& output_name){
5050
onnx::optimization::change_models(changed_model_path, output_model_path, changed_input_model_path, output_name);
5151
}
52+
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix){
53+
onnx::optimization::add_prefix_on_model(changed_model_path, output_model_path, prefix);
54+
}

onnxoptimizer/optimize_c_api/optimize_c_api.h

+1
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,6 @@ std::vector<std::string> check_redundant(std::string& changed_model_path,
3232
void change_models(std::string& changed_model_path,std::string& output_model_path,
3333
std::string& changed_input_model_path,
3434
const std::vector<std::string>& output_name);
35+
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix);
3536
#endif // ONNX_OPTIMIZER_OPTIMIZE_C_API_H
3637

onnxoptimizer/query_c_api/redundant_calculation_detection.cpp

+14-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <onnx/checker.h>
55
#include <onnx/onnx_pb.h>
66
#include <onnxoptimizer/optimize.h>
7+
#include <onnxoptimizer/query_c_api/model_merge.cpp>
78
#include <onnxoptimizer/model_util.h>
89
#include <vector>
910
#include "onnx/common/ir.h"
@@ -13,6 +14,13 @@
1314
#include "onnxoptimizer/pass_manager.h"
1415
#include "onnxoptimizer/pass_registry.h"
1516
namespace onnx::optimization {
17+
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix){
18+
ModelProto model_changed, model_output;
19+
onnx::optimization::loadModel(&model_changed, changed_model_path, true);
20+
add_model_prefix(&model_changed,prefix,&model_output);
21+
saveModel(&model_output,output_model_path);
22+
}
23+
1624
void mark_reachable_nodes(std::shared_ptr<Graph> graph, const std::string& input_name, std::unordered_set<std::string>& reachable_nodes) {
1725
for(auto node:graph->nodes()){
1826
if(reachable_nodes.find(node->name())!=reachable_nodes.end())
@@ -32,7 +40,6 @@ namespace onnx::optimization {
3240
ModelProto model_changed, model_compared;
3341
onnx::optimization::loadModel(&model_changed, changed_model_path, true);
3442
onnx::optimization::loadModel(&model_compared, compared_model_path, true);
35-
3643
std::shared_ptr<Graph> g_changed(ImportModelProto(model_changed));
3744
std::shared_ptr<Graph> g_compared(ImportModelProto(model_compared));
3845
auto node_list_changed = g_changed->nodes();
@@ -60,6 +67,7 @@ namespace onnx::optimization {
6067
}
6168
}
6269
std::vector<Node *> same_node_input_list(same_node_list);
70+
//将那些输出均有下一个相同结点将其作为输入的结点删掉
6371
for(int i = 0;i<same_node_list.size(); i++){
6472
auto node=same_node_list[i];
6573
auto outputs = node->outputs();
@@ -116,13 +124,13 @@ namespace onnx::optimization {
116124
}
117125

118126
std::vector<std::string> value_name_list;
119-
for(int i=0;i<final_node_list.size();i++){
120-
for(auto node: final_node_list){
121-
for(auto value:node->outputs()){
122-
value_name_list.push_back(value->uniqueName());
123-
}
127+
//for(int i=0;i<final_node_list.size();i++){
128+
for(auto node: final_node_list){
129+
for(auto value:node->outputs()){
130+
value_name_list.push_back(value->uniqueName());
124131
}
125132
}
133+
//}
126134
return value_name_list;
127135
}
128136

0 commit comments

Comments
 (0)