@@ -54,25 +54,24 @@ void add_prefix_on_model(std::string& changed_model_path, std::string& output_mo
54
54
onnx::optimization::add_prefix_on_model (changed_model_path, output_model_path, prefix);
55
55
}
56
56
57
- // / @brief
58
- // / @param input_model_path
59
- // / @param comparison_operator 1: ==, 2: <, 3: <=, 4: >, 5: >=
60
- // / @param threshold
61
- // / @param features
62
- // / @return optimized-model path
63
57
std::string optimize_on_decision_tree_predicate (std::string& input_model_path, uint8_t comparison_operator,
64
- float threshold) {
58
+ float threshold, int threads_count ) {
65
59
std::string mp1 = onnx::optimization::DTConvertRule::match (input_model_path);
66
- std::string mp2 = onnx::optimization::DTPruneRule::match (mp1, comparison_operator, threshold);
67
- return onnx::optimization::DTMergeRule::match (mp2);
60
+ std::string mp2 = onnx::optimization::DTPruneRule::match (mp1, comparison_operator, threshold, threads_count);
61
+ return onnx::optimization::DTMergeRule::match (mp2, threads_count);
62
+ }
68
63
69
- // std::string mp1 = onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
70
- // return onnx::optimization::DTMergeRule::match(mp1);
71
- // return onnx::optimization::DTMergeRule::match(mp3);
72
- // return onnx::optimization::DTConvertRule::match(input_model_path);
73
- // return onnx::optimization::DTPruneRule::match(input_model_path, comparison_operator, threshold);
64
+ std::string optimize_on_decision_tree_predicate_convert (std::string& input_model_path){
65
+ return onnx::optimization::DTConvertRule::match (input_model_path);
66
+ }
74
67
75
- // return onnx::optimization::DTConvertRule::match(input_model_path);
68
+ std::string optimize_on_decision_tree_predicate_prune (std::string& input_model_path, uint8_t comparison_operator, float threshold, int threads_count){
69
+ return onnx::optimization::DTPruneRule::match (input_model_path, comparison_operator, threshold, threads_count);
76
70
}
77
71
72
+ std::string optimize_on_decision_tree_predicate_merge (std::string& input_model_path, int threads_count){
73
+ return onnx::optimization::DTMergeRule::match (input_model_path, threads_count);
74
+ }
75
+
76
+
78
77
// -----------------------
0 commit comments