31
31
#include " halo/lib/transforms/onnxextension_legalizer.h"
32
32
#include " halo/lib/transforms/tfextension_legalizer.h"
33
33
#include " halo/lib/transforms/type_legalizer.h"
34
+ #include " halo/utils/cl_options.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/FileSystem.h"
36
37
#include " llvm/Support/Path.h"
37
38
38
39
using namespace halo ;
39
40
40
- static llvm::cl::list<std::string> ModelFiles (
41
- llvm::cl::Positional, llvm::cl::desc(" model file name." ),
42
- llvm::cl::OneOrMore);
43
-
44
- static llvm::cl::opt<Parser::Format> ModelFormat (
45
- " x" ,
46
- llvm::cl::desc (
47
- " format of the following input model files. Permissible formats "
48
- " include: TENSORFLOW CAFFE ONNX MXNET. If unspecified, the format is "
49
- " guessed base on file's extension." ),
50
- llvm::cl::init(Parser::Format::INVALID));
51
-
52
- static llvm::cl::opt<std::string> EntryFunctionName (
53
- " entry-func-name" , llvm::cl::desc(" name of entry function" ),
54
- llvm::cl::init(" " ));
55
-
56
41
static llvm::cl::opt<bool > PrintDiagnosticReport (
57
42
" print-diagnostic-report" , llvm::cl::desc(" Print diagnostic report" ),
58
43
llvm::cl::init(false ));
59
44
60
- // / Guess the model format based on input file extension.gg
61
- static Parser::Format InferFormat (
62
- const llvm::cl::list<std::string>& model_files, size_t file_idx) {
63
- llvm::StringRef ext = llvm::sys::path::extension (model_files[file_idx]);
64
- auto format = llvm::StringSwitch<Parser::Format>(ext)
65
- .Case (" .pb" , Parser::Format::TENSORFLOW)
66
- .Case (" .pbtxt" , Parser::Format::TENSORFLOW)
67
- .Case (" .prototxt" , Parser::Format::TENSORFLOW)
68
- .Case (" .onnx" , Parser::Format::ONNX)
69
- .Case (" .json" , Parser::Format::MXNET)
70
- .Default (Parser::Format::INVALID);
71
- // Check the next input file to see if it is caffe.
72
- if (format == Parser::Format::TENSORFLOW &&
73
- (file_idx + 1 < model_files.size ()) &&
74
- llvm::sys::path::extension (model_files[file_idx + 1 ]) == " .caffemodel" ) {
75
- format = Parser::Format::CAFFE;
76
- }
77
- return format;
78
- }
79
-
80
- static Status ParseModels (const llvm::cl::list<std::string>& model_files,
81
- const llvm::cl::opt<Parser::Format>& model_format,
82
- const llvm::cl::opt<std::string>& entry_func_name,
83
- const armory::Opts& opts, Module* module) {
84
- std::set<std::string> func_names;
85
- for (size_t i = 0 , e = model_files.size (); i < e; ++i) {
86
- Parser::Format format = model_format;
87
- if (format == Parser::Format::INVALID) {
88
- format = InferFormat (model_files, i);
89
- }
90
-
91
- FunctionBuilder func_builder (module);
92
- // Use stem of the input model as function name.
93
- std::string func_name = entry_func_name.empty ()
94
- ? llvm::sys::path::stem (model_files[i]).str ()
95
- : entry_func_name.getValue ();
96
- while (func_names.count (func_name) != 0 ) {
97
- func_name.append (" _" ).append (std::to_string (i));
98
- }
99
- func_names.insert (func_name);
100
- Function* func = func_builder.CreateFunction (func_name);
101
- std::vector<std::string> files{model_files[i]};
102
- if (format == Parser::Format::CAFFE || format == Parser::Format::MXNET) {
103
- HLCHECK (i + 1 < e);
104
- files.push_back (model_files[++i]);
105
- }
106
- if (Status status = Parser::Parse (func, format, files, opts);
107
- status != Status::SUCCESS) {
108
- return status;
109
- }
110
- }
111
- return Status::SUCCESS;
112
- }
113
-
114
45
int main (int argc, char ** argv) {
115
46
llvm::cl::ParseCommandLineOptions (argc, argv);
116
47
GlobalContext ctx;
@@ -119,8 +50,9 @@ int main(int argc, char** argv) {
119
50
Module m (ctx, " diagnostic_module" );
120
51
121
52
armory::Opts opts (PrintDiagnosticReport);
122
- if (ParseModels (ModelFiles, ModelFormat, EntryFunctionName, opts, &m) !=
123
- Status::SUCCESS) {
53
+ Parser::Format format = Parser::Format::INVALID;
54
+ if (ParseModels (ModelFiles, ModelFormat, EntryFunctionName, opts, &m,
55
+ &format) != Status::SUCCESS) {
124
56
return 1 ;
125
57
}
126
58
0 commit comments