Skip to content

Commit d9163e3

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[Refactor] Remove redundant code in diagnostic driver
1 parent 3e9324f commit d9163e3

File tree

1 file changed

+4
-72
lines changed

1 file changed

+4
-72
lines changed

Diff for: armory/diagnostic/driver.cc

+4-72
Original file line numberDiff line numberDiff line change
@@ -31,86 +31,17 @@
3131
#include "halo/lib/transforms/onnxextension_legalizer.h"
3232
#include "halo/lib/transforms/tfextension_legalizer.h"
3333
#include "halo/lib/transforms/type_legalizer.h"
34+
#include "halo/utils/cl_options.h"
3435
#include "llvm/Support/CommandLine.h"
3536
#include "llvm/Support/FileSystem.h"
3637
#include "llvm/Support/Path.h"
3738

3839
using namespace halo;
3940

40-
static llvm::cl::list<std::string> ModelFiles(
41-
llvm::cl::Positional, llvm::cl::desc("model file name."),
42-
llvm::cl::OneOrMore);
43-
44-
static llvm::cl::opt<Parser::Format> ModelFormat(
45-
"x",
46-
llvm::cl::desc(
47-
"format of the following input model files. Permissible formats "
48-
"include: TENSORFLOW CAFFE ONNX MXNET. If unspecified, the format is "
49-
"guessed base on file's extension."),
50-
llvm::cl::init(Parser::Format::INVALID));
51-
52-
static llvm::cl::opt<std::string> EntryFunctionName(
53-
"entry-func-name", llvm::cl::desc("name of entry function"),
54-
llvm::cl::init(""));
55-
5641
static llvm::cl::opt<bool> PrintDiagnosticReport(
5742
"print-diagnostic-report", llvm::cl::desc("Print diagnostic report"),
5843
llvm::cl::init(false));
5944

60-
/// Guess the model format based on input file extension.gg
61-
static Parser::Format InferFormat(
62-
const llvm::cl::list<std::string>& model_files, size_t file_idx) {
63-
llvm::StringRef ext = llvm::sys::path::extension(model_files[file_idx]);
64-
auto format = llvm::StringSwitch<Parser::Format>(ext)
65-
.Case(".pb", Parser::Format::TENSORFLOW)
66-
.Case(".pbtxt", Parser::Format::TENSORFLOW)
67-
.Case(".prototxt", Parser::Format::TENSORFLOW)
68-
.Case(".onnx", Parser::Format::ONNX)
69-
.Case(".json", Parser::Format::MXNET)
70-
.Default(Parser::Format::INVALID);
71-
// Check the next input file to see if it is caffe.
72-
if (format == Parser::Format::TENSORFLOW &&
73-
(file_idx + 1 < model_files.size()) &&
74-
llvm::sys::path::extension(model_files[file_idx + 1]) == ".caffemodel") {
75-
format = Parser::Format::CAFFE;
76-
}
77-
return format;
78-
}
79-
80-
static Status ParseModels(const llvm::cl::list<std::string>& model_files,
81-
const llvm::cl::opt<Parser::Format>& model_format,
82-
const llvm::cl::opt<std::string>& entry_func_name,
83-
const armory::Opts& opts, Module* module) {
84-
std::set<std::string> func_names;
85-
for (size_t i = 0, e = model_files.size(); i < e; ++i) {
86-
Parser::Format format = model_format;
87-
if (format == Parser::Format::INVALID) {
88-
format = InferFormat(model_files, i);
89-
}
90-
91-
FunctionBuilder func_builder(module);
92-
// Use stem of the input model as function name.
93-
std::string func_name = entry_func_name.empty()
94-
? llvm::sys::path::stem(model_files[i]).str()
95-
: entry_func_name.getValue();
96-
while (func_names.count(func_name) != 0) {
97-
func_name.append("_").append(std::to_string(i));
98-
}
99-
func_names.insert(func_name);
100-
Function* func = func_builder.CreateFunction(func_name);
101-
std::vector<std::string> files{model_files[i]};
102-
if (format == Parser::Format::CAFFE || format == Parser::Format::MXNET) {
103-
HLCHECK(i + 1 < e);
104-
files.push_back(model_files[++i]);
105-
}
106-
if (Status status = Parser::Parse(func, format, files, opts);
107-
status != Status::SUCCESS) {
108-
return status;
109-
}
110-
}
111-
return Status::SUCCESS;
112-
}
113-
11445
int main(int argc, char** argv) {
11546
llvm::cl::ParseCommandLineOptions(argc, argv);
11647
GlobalContext ctx;
@@ -119,8 +50,9 @@ int main(int argc, char** argv) {
11950
Module m(ctx, "diagnostic_module");
12051

12152
armory::Opts opts(PrintDiagnosticReport);
122-
if (ParseModels(ModelFiles, ModelFormat, EntryFunctionName, opts, &m) !=
123-
Status::SUCCESS) {
53+
Parser::Format format = Parser::Format::INVALID;
54+
if (ParseModels(ModelFiles, ModelFormat, EntryFunctionName, opts, &m,
55+
&format) != Status::SUCCESS) {
12456
return 1;
12557
}
12658

0 commit comments

Comments
 (0)