@@ -35,39 +35,48 @@ static bool ValidateFiles(const std::vector<std::string>& file_list) {
35
35
return true ;
36
36
}
37
37
38
- Status Parser::Parse (Function* function, Format format,
39
- const std::string& variant,
40
- const std::vector<std::string>& file_list,
41
- const armory::Opts& opts) {
42
- if (!ValidateFiles (file_list)) {
43
- return Status::FILE_NOT_EXIST;
44
- }
45
-
38
+ static std::unique_ptr<Parser> GetParser (Parser::Format format,
39
+ const std::string& variant,
40
+ const armory::Opts& opts) {
46
41
std::unique_ptr<Parser> parser (nullptr );
47
42
switch (format) {
48
- case Format::TENSORFLOW: {
43
+ case Parser:: Format::TENSORFLOW: {
49
44
if (opts.convert_to_ipu_graphdef ) {
50
45
parser = CreateIPUParser (variant);
51
46
} else {
52
47
parser = CreateTFParser (variant);
53
48
}
54
49
break ;
55
50
}
56
- case Format::ONNX: {
51
+ case Parser:: Format::ONNX: {
57
52
parser = CreateONNXParser ();
58
53
break ;
59
54
}
60
- case Format::TFLITE: {
55
+ case Parser:: Format::TFLITE: {
61
56
parser = CreateTFLITEParser ();
62
57
break ;
63
58
}
64
- case Format::CAFFE: {
59
+ case Parser:: Format::CAFFE: {
65
60
parser = CreateCAFFEParser ();
66
61
break ;
67
62
}
68
63
default :
69
64
HLCHECK (0 && " Unsupported format" );
70
65
}
66
+ return parser;
67
+ }
68
+
69
+ Status Parser::Parse (Function* function, Format format,
70
+ const std::string& variant,
71
+ const std::vector<std::string>& file_list,
72
+ const armory::Opts& opts) {
73
+ if (!ValidateFiles (file_list)) {
74
+ return Status::FILE_NOT_EXIST;
75
+ }
76
+ auto parser = GetParser (format, variant, opts);
77
+ if (parser == nullptr ) {
78
+ return Status::ILLEGAL_PARAM;
79
+ }
71
80
return parser->Parse (function, file_list, opts);
72
81
}
73
82
@@ -77,4 +86,27 @@ Status Parser::Parse(Function* function, Format format,
77
86
return Parse (function, format, " " , file_list, opts);
78
87
}
79
88
89
+ Status Parser::Parse (Function* function,
90
+ const std::vector<const char *>& buffers,
91
+ const std::vector<size_t >& buffer_sizes, Format format) {
92
+ armory::Opts opts;
93
+ std::string variant;
94
+ auto parser = GetParser (format, variant, opts);
95
+ if (parser == nullptr ) {
96
+ return Status::ILLEGAL_PARAM;
97
+ }
98
+ return parser->Parse (function, buffers, buffer_sizes);
99
+ }
100
+
101
+ Status Parser::Parse (Function* function, const std::vector<const void *>& model,
102
+ Format format) {
103
+ armory::Opts opts;
104
+ std::string variant;
105
+ auto parser = GetParser (format, variant, opts);
106
+ if (parser == nullptr ) {
107
+ return Status::ILLEGAL_PARAM;
108
+ }
109
+ return parser->Parse (function, model);
110
+ }
111
+
80
112
} // namespace halo
0 commit comments