Skip to content

Commit a1152f7

Browse files
committed
feat: add new models and redundant_calculation_detection.cpp
1 parent c095f1d commit a1152f7

16 files changed

+241
-11
lines changed

Diff for: CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ target_link_libraries(optimizer_example optimize_c_api)
8282
onnxopt_add_executable(predicate_example examples/optimizer_c_example/predicate_example.cpp)
8383
target_link_libraries(predicate_example optimize_c_api)
8484

85+
onnxopt_add_executable(redundant_example examples/optimizer_c_example/redundant.cpp)
86+
target_link_libraries(redundant_example optimize_c_api)
8587

8688
#configure_file(onnxoptimizer/value_info/int_value_info.onnx
8789
# ${CMAKE_BINARY_DIR}/int_value_info.onnx COPYONLY)

Diff for: examples/onnx_input_model/flights_lr.onnx

138 KB
Binary file not shown.

Diff for: examples/onnx_input_model/flights_nb.onnx

125 KB
Binary file not shown.

Diff for: examples/onnx_input_model/flights_sgd.onnx

98.6 KB
Binary file not shown.

Diff for: examples/onnx_input_model/model_out.onnx

7.98 MB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_lr.onnx

1.07 KB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_lr2.onnx

1.2 KB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_nb.onnx

1.76 KB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_nb2.onnx

1.83 KB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_sgd.onnx

1.45 KB
Binary file not shown.

Diff for: examples/onnx_input_model/neo_sgd2.onnx

1.5 KB
Binary file not shown.

Diff for: examples/onnx_output_model/model_opted.onnx

-177 KB
Binary file not shown.

Diff for: examples/optimizer_c_example/redundant.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//
2+
// Created by xyyang's mac on 2024/7/7.
3+
//
4+
#include <string>
5+
#include "onnxoptimizer/optimize_c_api/optimize_c_api.h"
6+
7+
int main(int argc, char* argv[]) {
8+
9+
std::string path2="../examples/onnx_input_model/neo_nb2.onnx";
10+
std::string path3="../examples/onnx_input_model/neo_lr2.onnx";
11+
std::string path4="../examples/onnx_input_model/news_lr.onnx";
12+
std::string path5="../examples/onnx_input_model/news_nb.onnx";
13+
std::string path6="../examples/onnx_input_model/model_lr.onnx";
14+
std::string path7="../examples/onnx_input_model/model_nb.onnx";
15+
std::string path8="../examples/onnx_input_model/flights_lr.onnx";
16+
std::string path9="../examples/onnx_input_model/flights_nb.onnx";
17+
std::string path10="../examples/onnx_input_model/fakenews_lr.onnx";
18+
std::string path11="../examples/onnx_input_model/fakenews_nb.onnx";
19+
std::string outpath="../examples/onnx_input_model/model_out.onnx";
20+
std::string inpath="../examples/onnx_input_model/model_in.onnx";
21+
22+
std::vector<std::string> tmp = check_redundant(path10,path11);
23+
change_models(path10,outpath,inpath,tmp);
24+
25+
}

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "optimize_c_api.h"
88
#include "onnxoptimizer/query_c_api/model_merge.cpp"
99
#include "onnxoptimizer/query_c_api/predicate_push_down.cpp"
10+
#include "onnxoptimizer/query_c_api/redundant_calculation_detection.cpp"
1011

1112
void optimize_with_model_path(
1213
std::string& mp_in_path1,
@@ -40,3 +41,11 @@ void optimize_on_merged_model(std::string& mp_in_path,std::string& mp_out_path){
4041
onnx::optimization::OptimizeOnMergedModel(mp_in_path, mp_out_path);
4142
}
4243

44+
std::vector<std::string> check_redundant(std::string& changed_model_path, std::string& compared_model_path){
45+
return onnx::optimization::check_redundant(changed_model_path, compared_model_path);
46+
}
47+
void change_models(std::string& changed_model_path,std::string& output_model_path,
48+
std::string& changed_input_model_path,
49+
const std::vector<std::string>& output_name){
50+
onnx::optimization::change_models(changed_model_path, output_model_path, changed_input_model_path, output_name);
51+
}

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.h

+22-11
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,31 @@
66
#define ONNX_OPTIMIZER_OPTIMIZE_C_API_H
77

88
void optimize_with_model_path(std::string& mp_in_path1,
9-
std::string& mp_in_path2,
10-
std::string& mp_name1,
11-
std::string& mp_name2,
12-
std::string& mp_out_path);
9+
std::string& mp_in_path2, std::string& mp_name1,
10+
std::string& mp_name2, std::string& mp_out_path);
11+
12+
void merge_single_model_with_predicate(std::string& onnx_model_path,
13+
std::string& predicate,
14+
std::string& value_type,
15+
std::string prefix, int count);
1316

14-
void merge_single_model_with_predicate(std::string& onnx_model_path, std::string& predicate,
15-
std::string& value_type, std::string prefix,int count);
17+
void merge_double_models_with_predicate(std::string& onnx_model_path,
18+
std::string& predicate,
19+
std::string prefix_l,
20+
std::string prefix_r);
1621

17-
void merge_double_models_with_predicate(std::string& onnx_model_path,std::string& predicate,
18-
std::string prefix_l,std::string prefix_r);
22+
void merge_with_model_path(std::string& mp_in_path1, std::string& mp_in_path2,
23+
std::string& mp_name1, std::string& mp_name2,
24+
std::string& mp_out_path);
25+
26+
void optimize_on_merged_model(std::string& mp_in_path,
27+
std::string& mp_out_path);
1928

20-
void merge_with_model_path(std::string& mp_in_path1,std::string& mp_in_path2,std::string& mp_name1,std::string& mp_name2,
21-
std::string& mp_out_path);
29+
std::vector<std::string> check_redundant(std::string& changed_model_path,
30+
std::string& compared_model_path);
2231

23-
void optimize_on_merged_model(std::string& mp_in_path,std::string& mp_out_path);
32+
void change_models(std::string& changed_model_path,std::string& output_model_path,
33+
std::string& changed_input_model_path,
34+
const std::vector<std::string>& output_name);
2435
#endif // ONNX_OPTIMIZER_OPTIMIZE_C_API_H
2536

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
//
2+
// Created by xyyang's mac on 2024/7/1.
3+
//
4+
#include <onnx/checker.h>
5+
#include <onnx/onnx_pb.h>
6+
#include <onnxoptimizer/optimize.h>
7+
#include <onnxoptimizer/model_util.h>
8+
#include <vector>
9+
#include "onnx/common/ir.h"
10+
#include "onnx/common/ir_pb_converter.h"
11+
#include "onnx/proto_utils.h"
12+
13+
#include "onnxoptimizer/pass_manager.h"
14+
#include "onnxoptimizer/pass_registry.h"
15+
namespace onnx::optimization {
16+
void mark_reachable_nodes(std::shared_ptr<Graph> graph, const std::string& input_name, std::unordered_set<std::string>& reachable_nodes) {
17+
for(auto node:graph->nodes()){
18+
if(reachable_nodes.find(node->name())!=reachable_nodes.end())
19+
continue;
20+
for(auto tmp_input:node->inputs()){
21+
if(tmp_input->uniqueName()==input_name){
22+
reachable_nodes.insert(node->name());
23+
for(auto output:node->outputs()){
24+
mark_reachable_nodes(graph,output->uniqueName(),reachable_nodes);
25+
}
26+
}
27+
}
28+
}
29+
}
30+
31+
std::vector<std::string> check_redundant(std::string& changed_model_path, std::string& compared_model_path){
32+
ModelProto model_changed, model_compared;
33+
onnx::optimization::loadModel(&model_changed, changed_model_path, true);
34+
onnx::optimization::loadModel(&model_compared, compared_model_path, true);
35+
36+
std::shared_ptr<Graph> g_changed(ImportModelProto(model_changed));
37+
std::shared_ptr<Graph> g_compared(ImportModelProto(model_compared));
38+
auto node_list_changed = g_changed->nodes();
39+
auto node_list_compared = g_compared->nodes();
40+
// 使用hashmap记录被比较的模型的结点
41+
std::unordered_map<Node *, Node *, CSENodeHash, CSEEqual> hash_map;
42+
for (auto it = node_list_compared.begin(); it != node_list_compared.end(); ++it) {
43+
auto node = *it;
44+
if (!node->hasUses() || !IsSupportedByCSE(node))
45+
continue;
46+
47+
if (hash_map.find(node) == hash_map.end())
48+
hash_map[node] = node;
49+
else
50+
//同一个模型内应该不会有重复的算子?
51+
continue;
52+
}
53+
std::vector<Node *> same_node_list;
54+
for(auto it = node_list_changed.begin(); it != node_list_changed.end(); ++it){
55+
auto node = *it;
56+
if (!node->hasUses() || !IsSupportedByCSE(node))
57+
continue;
58+
if (hash_map.find(node) != hash_map.end()) {
59+
same_node_list.push_back(node);
60+
}
61+
}
62+
std::vector<Node *> same_node_input_list(same_node_list);
63+
for(int i = 0;i<same_node_list.size(); i++){
64+
auto node=same_node_list[i];
65+
auto outputs = node->outputs();
66+
int tar = outputs.size();
67+
int count=0;
68+
for(int i=0;i<outputs.size();i++){
69+
bool found= false;
70+
auto output=outputs[i];
71+
for(Node* node1: same_node_input_list){
72+
if(found) break;
73+
auto inputs = node1->inputs();
74+
for(int j=0;j<inputs.size();j++){
75+
if(output->uniqueName()==inputs[j]->uniqueName()){
76+
count++;
77+
found= true;
78+
break;
79+
}
80+
}
81+
}
82+
}
83+
84+
if(count==tar) {
85+
auto it = std::remove(same_node_list.begin(), same_node_list.end(), node);
86+
same_node_list.erase(it, same_node_list.end());
87+
i--;
88+
}
89+
}
90+
91+
// 检测没有依赖的node
92+
std::vector<Node *> final_node_list;
93+
for(int i=0; i<same_node_list.size();i++){
94+
auto it = std::remove(same_node_input_list.begin(), same_node_input_list.end(), same_node_list[i]);
95+
same_node_input_list.erase(it, same_node_input_list.end());
96+
}
97+
98+
for(int i=0; i<same_node_list.size();i++){
99+
auto inputs=same_node_list[i]->inputs();
100+
bool sign=false;
101+
for(int j=0;j<inputs.size();j++){
102+
if(sign) break;
103+
auto input=inputs[j];
104+
for(int m=0; m<same_node_input_list.size(); m++){
105+
if(sign) break;
106+
auto outputs=same_node_input_list[m]->outputs();
107+
for(int n=0; n<outputs.size(); n++){
108+
if(outputs[n]->uniqueName()==input->uniqueName()){
109+
sign=true;
110+
final_node_list.push_back(same_node_list[i]);
111+
break;
112+
}
113+
}
114+
}
115+
}
116+
}
117+
118+
std::vector<std::string> value_name_list;
119+
for(int i=0;i<final_node_list.size();i++){
120+
for(auto node: final_node_list){
121+
for(auto value:node->outputs()){
122+
value_name_list.push_back(value->uniqueName());
123+
}
124+
}
125+
}
126+
return value_name_list;
127+
}
128+
129+
void change_models(std::string& changed_model_path, std::string& output_model_path, std::string& changed_input_model_path,const std::vector<std::string>& output_name){
130+
ModelProto model_changed;
131+
onnx::optimization::loadModel(&model_changed, changed_model_path, true);
132+
std::shared_ptr<Graph> g_changed(ImportModelProto(model_changed));
133+
onnx::ValueInfoProto output;
134+
onnx::TypeProto input_value_type;
135+
for(const std::string& name:output_name){
136+
bool found=false;
137+
for(auto node:g_changed->nodes()){
138+
if(found) break;
139+
for(auto tmp_output:node->outputs()){
140+
if(tmp_output->uniqueName()==name){
141+
output.set_allocated_type(tmp_output->valueType());
142+
output.set_name(name);
143+
*model_changed.mutable_graph()->add_output() = output;
144+
input_value_type.CopyFrom(output.type());
145+
found= true;
146+
break;
147+
}
148+
}
149+
}
150+
}
151+
//onnx::checker::check_model(model_changed);
152+
saveModel(&model_changed,output_model_path);
153+
154+
// below del input
155+
ModelProto model_input_changed;
156+
onnx::optimization::loadModel(&model_input_changed, changed_model_path, true);
157+
std::shared_ptr<Graph> g_input_changed(ImportModelProto(model_changed));
158+
model_input_changed.mutable_graph()->clear_input();
159+
onnx::ValueInfoProto input;
160+
input.mutable_type()->CopyFrom(input_value_type);
161+
input.set_name(output_name[0]);
162+
// 暂时固定输入形状和类型
163+
onnx::TensorShapeProto_Dimension input_dim1,input_dim2;
164+
//input_dim2.set_dim_value(1);
165+
*input.mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()=input_dim1;
166+
*input.mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim()=input_dim2;
167+
input.mutable_type()->mutable_tensor_type()->set_elem_type(1);
168+
*model_input_changed.mutable_graph()->add_input()=input;
169+
// del nodes without input
170+
std::unordered_set<std::string> reachable_nodes;
171+
mark_reachable_nodes(g_input_changed,output_name[0],reachable_nodes);
172+
for(int i=0;i<model_input_changed.mutable_graph()->node_size();i++){
173+
if(reachable_nodes.find(model_input_changed.mutable_graph()->node(i).name())==reachable_nodes.end()){
174+
model_input_changed.mutable_graph()->mutable_node()->DeleteSubrange(i,1);
175+
i--;
176+
}
177+
}
178+
saveModel(&model_input_changed,changed_input_model_path);
179+
}
180+
181+
182+
183+
}

0 commit comments

Comments
 (0)