Skip to content

Commit d2d12a1

Browse files
committed
feat: add retree rules
2 parents fc7e154 + d663544 commit d2d12a1

File tree

12 files changed

+1384
-0
lines changed

12 files changed

+1384
-0
lines changed

Diff for: .gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,7 @@ compile_commands.json
8282
.mypy_cache
8383
virtualenv
8484
venv
85+
86+
examples/model4test
87+
build.sh
88+
run.sh

Diff for: CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ target_link_libraries(predicate_example optimize_c_api)
8585
onnxopt_add_executable(redundant_example examples/optimizer_c_example/redundant.cpp)
8686
target_link_libraries(redundant_example optimize_c_api)
8787

88+
onnxopt_add_executable(optimize_dt_example examples/optimizer_c_example/optimize_dt_example.cpp)
89+
target_link_libraries(optimize_dt_example optimize_c_api)
90+
8891
#configure_file(onnxoptimizer/value_info/int_value_info.onnx
8992
# ${CMAKE_BINARY_DIR}/int_value_info.onnx COPYONLY)
9093
#if(BUILD_ONNX_PYTHON)

Diff for: build.sh

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
3+
# 删除旧的 build 目录
4+
rm -rf build
5+
6+
# 创建新的 build 目录并配置 Debug 构建
7+
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
8+
9+
# 构建项目,Debug 模式,先清理再并行编译
10+
cmake --build build --config Debug --clean-first --parallel 10

Diff for: examples/model4test/clf2regtest.onnx

19.5 KB
Binary file not shown.

Diff for: examples/model4test/convert_test.onnx

1.76 KB
Binary file not shown.

Diff for: examples/model4test/convert_test_reg.onnx

1.35 KB
Binary file not shown.

Diff for: examples/model4test/convert_test_test.onnx

2.6 KB
Binary file not shown.

Diff for: examples/optimizer_c_example/optimize_dt_example.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//
2+
// Created by xyyang on 23-11-14.
3+
//
4+
//#include <onnxoptimizer/model_util.h>
5+
6+
#include <string>
7+
#include <vector>
8+
#include <iostream>
9+
#include "onnxoptimizer/optimize_c_api/optimize_c_api.h"
10+
11+
// 1: ==, 2: <, 3: <=, 4: >, 5: >=
12+
int main(int argc, char* argv[]) {
13+
std::string path1 = "../examples/onnx_input_model/titanic_pipeline.onnx";
14+
std::string path2 = "../examples/onnx_input_model/iris_dataset_pipeline.onnx";
15+
std::string path3 = "/home/ding/duckdb_project/onnx_optimizer_C/examples/onnx_input_model/house_16H_d10_l281_n561_20240922063836.onnx";
16+
std::string testmodelpath = "/home/ding/duckdb_project/onnx_optimizer_C/examples/model4test/convert_test.onnx";
17+
std::string testmodelpath1 = "/home/ding/duckdb_project/onnx_optimizer_C/examples/model4test/clf2regtest.onnx";
18+
std::string testmodelpath2 = "/home/ding/duckdb_project/onnx_optimizer_C/examples/model4test/wine_quality_d11_l280_n559_20241209164224_with_zipmap.onnx";
19+
std::string testmodelpath3 = "/home/ding/duckdb_project/onnx_optimizer_C/examples/model4test/wine_quality_d11_l280_n559_20241209164224.onnx";
20+
// ?
21+
// std::vector<std::string> features = {};
22+
// std::string path3 = "../examples/onnx_input_model/house_16H_d10_l281_n561_20240922063836.onnx";
23+
// optimize_on_decision_tree_predicate(path1, 2, 10);
24+
// optimize_on_decision_tree_predicate(path2, 2, 10);
25+
// std::cout << optimize_on_decision_tree_predicate(path3, 4, 10, &features) << std::endl;
26+
std::cout << optimize_on_decision_tree_predicate(testmodelpath3, 1, 1) << std::endl;
27+
// std::cout << optimize_on_decision_tree_predicate(testmodelpath, 1, 1) << std::endl;
28+
return 0;
29+
}

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "optimize_c_api.h"
88
//#include "onnxoptimizer/query_c_api/model_merge.cpp"
9+
#include "onnxoptimizer/query_c_api/decision_tree_predicate.cpp"
910
#include "onnxoptimizer/query_c_api/predicate_push_down.cpp"
1011
#include "onnxoptimizer/query_c_api/redundant_calculation_detection.cpp"
1112

@@ -52,3 +53,26 @@ void change_models(std::string& changed_model_path,std::string& output_model_pat
5253
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix){
5354
onnx::optimization::add_prefix_on_model(changed_model_path, output_model_path, prefix);
5455
}
56+
57+
/// @brief
58+
/// @param input_model_path
59+
/// @param comparison_operator 1: ==, 2: <, 3: <=, 4: >, 5: >=
60+
/// @param threshold
61+
/// @param features
62+
/// @return optimized-model path
63+
std::string optimize_on_decision_tree_predicate(std::string& input_model_path, uint8_t comparison_operator,
64+
float threshold) {
65+
std::string mp1 = onnx::optimization::DTConvertRule::match(input_model_path);
66+
std::string mp2 = onnx::optimization::DTPruneRule::match(mp1, comparison_operator, threshold);
67+
return onnx::optimization::DTMergeRule::match(mp2);
68+
69+
// std::string mp1 = onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
70+
// return onnx::optimization::DTMergeRule::match(mp1);
71+
// return onnx::optimization::DTMergeRule::match(mp3);
72+
// return onnx::optimization::DTConvertRule::match(input_model_path);
73+
// return onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
74+
75+
// return onnx::optimization::DTConvertRule::match(input_model_path);
76+
}
77+
78+
// -----------------------

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.h

+4
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,9 @@ void change_models(std::string& changed_model_path,std::string& output_model_pat
3333
std::string& changed_input_model_path,
3434
const std::vector<std::string>& output_name);
3535
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix);
36+
37+
//--------------------------
38+
std::string optimize_on_decision_tree_predicate(std::string& input_model_path, uint8_t comparison_operator, float threshold);
39+
3640
#endif // ONNX_OPTIMIZER_OPTIMIZE_C_API_H
3741

0 commit comments

Comments
 (0)