Skip to content

Commit d5e835c

Browse files
Merge pull request #3 from ddding-z/master
ReTree Random Forest Implementation
2 parents b7673a9 + 85cdfea commit d5e835c

File tree

7 files changed

+931
-730
lines changed

7 files changed

+931
-730
lines changed

Diff for: .gitmodules

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
url = https://github.com/protocolbuffers/protobuf.git
44
[submodule "third_party/onnx"]
55
path = third_party/onnx
6-
url = git@github.com:lovelynewlife/onnx.git
6+
url = https://github.com/lovelynewlife/onnx.git

Diff for: CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cmake_minimum_required(VERSION 3.22)
1+
cmake_minimum_required(VERSION 3.16)
22

33
# For std::filesystem
44
# Must be a cache variable and be set before project()

Diff for: build.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ rm -rf build
77
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
88

99
# 构建项目,Debug 模式,先清理再并行编译
10-
cmake --build build --config Debug --clean-first --parallel 10
10+
cmake --build build --config Debug --clean-first --parallel 80

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.cpp

+14-15
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,24 @@ void add_prefix_on_model(std::string& changed_model_path, std::string& output_mo
5454
onnx::optimization::add_prefix_on_model(changed_model_path, output_model_path, prefix);
5555
}
5656

57-
/// @brief
58-
/// @param input_model_path
59-
/// @param comparison_operator 1: ==, 2: <, 3: <=, 4: >, 5: >=
60-
/// @param threshold
61-
/// @param features
62-
/// @return optimized-model path
6357
std::string optimize_on_decision_tree_predicate(std::string& input_model_path, uint8_t comparison_operator,
64-
float threshold) {
58+
float threshold, int threads_count) {
6559
std::string mp1 = onnx::optimization::DTConvertRule::match(input_model_path);
66-
std::string mp2 = onnx::optimization::DTPruneRule::match(mp1, comparison_operator, threshold);
67-
return onnx::optimization::DTMergeRule::match(mp2);
60+
std::string mp2 = onnx::optimization::DTPruneRule::match(mp1, comparison_operator, threshold, threads_count);
61+
return onnx::optimization::DTMergeRule::match(mp2, threads_count);
62+
}
6863

69-
// std::string mp1 = onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
70-
// return onnx::optimization::DTMergeRule::match(mp1);
71-
// return onnx::optimization::DTMergeRule::match(mp3);
72-
// return onnx::optimization::DTConvertRule::match(input_model_path);
73-
// return onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
64+
std::string optimize_on_decision_tree_predicate_convert(std::string& input_model_path){
65+
return onnx::optimization::DTConvertRule::match(input_model_path);
66+
}
7467

75-
// return onnx::optimization::DTConvertRule::match(input_model_path);
68+
std::string optimize_on_decision_tree_predicate_prune(std::string& input_model_path, uint8_t comparison_operator, float threshold, int threads_count){
69+
return onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold, threads_count);
7670
}
7771

72+
std::string optimize_on_decision_tree_predicate_merge(std::string& input_model_path, int threads_count){
73+
return onnx::optimization::DTMergeRule::match(input_model_path, threads_count);
74+
}
75+
76+
7877
// -----------------------

Diff for: onnxoptimizer/optimize_c_api/optimize_c_api.h

+17-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,23 @@ void change_models(std::string& changed_model_path,std::string& output_model_pat
3434
const std::vector<std::string>& output_name);
3535
void add_prefix_on_model(std::string& changed_model_path, std::string& output_model_path, std::string& prefix);
3636

37-
//--------------------------
38-
std::string optimize_on_decision_tree_predicate(std::string& input_model_path, uint8_t comparison_operator, float threshold);
37+
/**
38+
* @brief retree optimization
39+
* 1. convert
40+
* 2. prune
41+
* 3. merge
42+
* @param input_model_path
43+
* @param comparison_operator 1: ==, 2: <, 3: <=, 4: >, 5: >=
44+
* @param threshold
45+
* @param threads_count default=1
46+
* @return optimized model path
47+
*/
48+
std::string optimize_on_decision_tree_predicate(std::string& input_model_path, uint8_t comparison_operator, float threshold, int threads_count = 1);
49+
50+
std::string optimize_on_decision_tree_predicate_convert(std::string& input_model_path);
51+
std::string optimize_on_decision_tree_predicate_prune(std::string& input_model_path, uint8_t comparison_operator, float threshold, int threads_count = 1);
52+
std::string optimize_on_decision_tree_predicate_merge(std::string& input_model_path, int threads_count = 1);
53+
3954

4055
#endif // ONNX_OPTIMIZER_OPTIMIZE_C_API_H
4156

0 commit comments

Comments
 (0)