//
// Created by xyyang's mac on 2024/7/1.
//
#include <onnx/checker.h>
#include <onnx/onnx_pb.h>
#include <onnxoptimizer/optimize.h>
#include <onnxoptimizer/model_util.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "onnx/common/ir.h"
#include "onnx/common/ir_pb_converter.h"
#include "onnx/proto_utils.h"

#include "onnxoptimizer/pass_manager.h"
#include "onnxoptimizer/pass_registry.h"
namespace onnx::optimization {

  // Recursively mark every node that is (transitively) reachable from the value
  // named `input_name`, recording the node names in `reachable_nodes`.
  void mark_reachable_nodes(const std::shared_ptr<Graph>& graph, const std::string& input_name, std::unordered_set<std::string>& reachable_nodes) {
    for (auto node : graph->nodes()) {
      if (reachable_nodes.find(node->name()) != reachable_nodes.end())
        continue;
      for (auto tmp_input : node->inputs()) {
        if (tmp_input->uniqueName() == input_name) {
          reachable_nodes.insert(node->name());
          for (auto output : node->outputs()) {
            mark_reachable_nodes(graph, output->uniqueName(), reachable_nodes);
          }
        }
      }
    }
  }

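  // Compare two models and return the unique names of the output values that sit
  // on the boundary of the region the models share: nodes of the changed model
  // that hash-match a node of the compared model (CSE hash), have at least one
  // output that is not consumed by another matched node, and consume at least
  // one output produced by an interior matched node.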
  std::vector<std::string> check_redundant(std::string& changed_model_path, std::string& compared_model_path) {
    ModelProto model_changed, model_compared;
    onnx::optimization::loadModel(&model_changed, changed_model_path, true);
    onnx::optimization::loadModel(&model_compared, compared_model_path, true);

    std::shared_ptr<Graph> g_changed(ImportModelProto(model_changed));
    std::shared_ptr<Graph> g_compared(ImportModelProto(model_compared));
    auto node_list_changed = g_changed->nodes();
    auto node_list_compared = g_compared->nodes();

    // Record the nodes of the compared model in a hash map keyed by their CSE hash.
    std::unordered_map<Node*, Node*, CSENodeHash, CSEEqual> hash_map;
    for (auto it = node_list_compared.begin(); it != node_list_compared.end(); ++it) {
      auto node = *it;
      if (!node->hasUses() || !IsSupportedByCSE(node))
        continue;
      // A single model should not contain duplicate operators, so the first hit wins.
      if (hash_map.find(node) == hash_map.end())
        hash_map[node] = node;
    }
    // Nodes of the changed model that hash-match a node of the compared model.
    std::vector<Node*> same_node_list;
    for (auto it = node_list_changed.begin(); it != node_list_changed.end(); ++it) {
      auto node = *it;
      if (!node->hasUses() || !IsSupportedByCSE(node))
        continue;
      if (hash_map.find(node) != hash_map.end()) {
        same_node_list.push_back(node);
      }
    }

    // Drop every matched node whose outputs are all consumed by other matched
    // nodes; what remains are the matched nodes on the boundary of the region.
    std::vector<Node*> same_node_input_list(same_node_list);
    for (int i = 0; i < same_node_list.size(); i++) {
      auto node = same_node_list[i];
      auto outputs = node->outputs();
      int tar = outputs.size();
      int count = 0;
      for (int k = 0; k < outputs.size(); k++) {
        bool found = false;
        auto output = outputs[k];
        for (Node* node1 : same_node_input_list) {
          if (found) break;
          auto inputs = node1->inputs();
          for (int j = 0; j < inputs.size(); j++) {
            if (output->uniqueName() == inputs[j]->uniqueName()) {
              count++;
              found = true;
              break;
            }
          }
        }
      }

      if (count == tar) {
        auto it = std::remove(same_node_list.begin(), same_node_list.end(), node);
        same_node_list.erase(it, same_node_list.end());
        i--;
      }
    }

    // Detect the boundary nodes that actually depend on the interior of the
    // matched region. First reduce same_node_input_list to the interior nodes
    // by removing every boundary node that survived the pruning above.
    std::vector<Node*> final_node_list;
    for (int i = 0; i < same_node_list.size(); i++) {
      auto it = std::remove(same_node_input_list.begin(), same_node_input_list.end(), same_node_list[i]);
      same_node_input_list.erase(it, same_node_input_list.end());
    }

    // Keep a boundary node only if at least one of its inputs is produced by
    // an interior node.
    for (int i = 0; i < same_node_list.size(); i++) {
      auto inputs = same_node_list[i]->inputs();
      bool sign = false;
      for (int j = 0; j < inputs.size(); j++) {
        if (sign) break;
        auto input = inputs[j];
        for (int m = 0; m < same_node_input_list.size(); m++) {
          if (sign) break;
          auto outputs = same_node_input_list[m]->outputs();
          for (int n = 0; n < outputs.size(); n++) {
            if (outputs[n]->uniqueName() == input->uniqueName()) {
              sign = true;
              final_node_list.push_back(same_node_list[i]);
              break;
            }
          }
        }
      }
    }

    // Collect the output value names of the remaining nodes.
    std::vector<std::string> value_name_list;
    for (auto node : final_node_list) {
      for (auto value : node->outputs()) {
        value_name_list.push_back(value->uniqueName());
      }
    }
    return value_name_list;
  }

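  // Given the boundary value names reported by check_redundant(), write two new
  // models: one that re-exports those values as extra graph outputs
  // (output_model_path), and one whose graph input is replaced by the first
  // boundary value and which keeps only the nodes reachable from it
  // (changed_input_model_path).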
  void change_models(std::string& changed_model_path, std::string& output_model_path, std::string& changed_input_model_path, const std::vector<std::string>& output_name) {
    ModelProto model_changed;
    onnx::optimization::loadModel(&model_changed, changed_model_path, true);
    std::shared_ptr<Graph> g_changed(ImportModelProto(model_changed));
    onnx::ValueInfoProto output;
    onnx::TypeProto input_value_type;
    // Expose every requested value as an additional graph output.
    for (const std::string& name : output_name) {
      bool found = false;
      for (auto node : g_changed->nodes()) {
        if (found) break;
        for (auto tmp_output : node->outputs()) {
          if (tmp_output->uniqueName() == name) {
            // Copy the type rather than adopting the pointer with
            // set_allocated_type(), which would take ownership of a TypeProto
            // the graph value may still own.
            if (tmp_output->valueType() != nullptr)
              output.mutable_type()->CopyFrom(*tmp_output->valueType());
            output.set_name(name);
            *model_changed.mutable_graph()->add_output() = output;
            input_value_type.CopyFrom(output.type());
            found = true;
            break;
          }
        }
      }
    }
    // onnx::checker::check_model(model_changed);
    saveModel(&model_changed, output_model_path);

    // Below: replace the graph inputs of a fresh copy of the model.
    ModelProto model_input_changed;
    onnx::optimization::loadModel(&model_input_changed, changed_model_path, true);
    std::shared_ptr<Graph> g_input_changed(ImportModelProto(model_input_changed));
    model_input_changed.mutable_graph()->clear_input();
    onnx::ValueInfoProto input;
    input.mutable_type()->CopyFrom(input_value_type);
    input.set_name(output_name[0]);
    // For now, hard-code the input shape (two unknown dims) and element type (1 = FLOAT).
    onnx::TensorShapeProto_Dimension input_dim1, input_dim2;
    // input_dim2.set_dim_value(1);
    *input.mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim() = input_dim1;
    *input.mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim() = input_dim2;
    input.mutable_type()->mutable_tensor_type()->set_elem_type(1);
    *model_input_changed.mutable_graph()->add_input() = input;
    // Delete the nodes that are not reachable from the new input.
    std::unordered_set<std::string> reachable_nodes;
    mark_reachable_nodes(g_input_changed, output_name[0], reachable_nodes);
    for (int i = 0; i < model_input_changed.mutable_graph()->node_size(); i++) {
      if (reachable_nodes.find(model_input_changed.mutable_graph()->node(i).name()) == reachable_nodes.end()) {
        model_input_changed.mutable_graph()->mutable_node()->DeleteSubrange(i, 1);
        i--;
      }
    }
    saveModel(&model_input_changed, changed_input_model_path);
  }

}  // namespace onnx::optimization
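
// A minimal usage sketch, not part of the original commit: the file paths and
// the wrapper function name below are hypothetical, and only the call order is
// implied by the functions above (check_redundant() reports the boundary value
// names, change_models() then writes the two rewritten models).
inline void example_usage() {
  std::string changed = "model_changed.onnx";          // hypothetical path
  std::string compared = "model_compared.onnx";        // hypothetical path
  std::string with_extra_outputs = "model_with_outputs.onnx";
  std::string with_new_input = "model_with_new_input.onnx";

  std::vector<std::string> boundary_values =
      onnx::optimization::check_redundant(changed, compared);
  if (!boundary_values.empty()) {
    onnx::optimization::change_models(changed, with_extra_outputs,
                                      with_new_input, boundary_values);
  }
}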