Skip to content

Commit a74a554

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[Parser] Allow parsing from memory data
1 parent c0df923 commit a74a554

File tree

10 files changed

+180
-15
lines changed

10 files changed

+180
-15
lines changed

include/halo/lib/parser/parser.h

+15-1
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,22 @@ class Parser {
4848
virtual Status Parse(Function* function,
4949
const std::vector<std::string>& file_list,
5050
const armory::Opts& opts) = 0;
51+
virtual Status Parse(Function* function,
52+
const std::vector<const char*>& buffers,
53+
const std::vector<size_t>& buffer_sizes) = 0;
54+
virtual Status Parse(Function* function,
55+
const std::vector<const void*>& model_defs) = 0;
56+
57+
/// Parse a model from specified data structure (e.g. graphdef).
58+
static Status Parse(Function* function, const std::vector<const void*>& model,
59+
Format format);
60+
61+
/// Parse a model from buffers based on specified format.
62+
static Status Parse(Function* function,
63+
const std::vector<const char*>& buffers,
64+
const std::vector<size_t>& buffer_sizes, Format format);
5165

52-
/// Parse a file from `file_lists` based on specified format. `variant`
66+
/// Parse a model from `file_lists` based on specified format. `variant`
5367
/// specifies sub variants like version etc., which can be empty.
5468
static Status Parse(Function* function, Format format,
5569
const std::string& variant,

lib/parser/caffe/caffe_parser.cc

+21
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,27 @@ CAFFEAttrs::CAFFEAttrs(const caffe::BlobShape& shape) {
3737

3838
CAFFEParser::~CAFFEParser() {}
3939

40+
Status CAFFEParser::Parse(Function* function,
41+
const std::vector<const char*>& buffers,
42+
const std::vector<size_t>& buffer_sizes) {
43+
return Status::ASSERTION;
44+
}
45+
46+
Status CAFFEParser::Parse(Function* function,
47+
const std::vector<const void*>& model_defs) {
48+
if (model_defs.size() < 2) {
49+
return Status::FILE_NOT_EXIST;
50+
}
51+
armory::Opts opts;
52+
BasicBlockBuilder bb_builder(function);
53+
BasicBlock* bb = bb_builder.CreateBasicBlock("bb0");
54+
const caffe::NetParameter* net_param =
55+
reinterpret_cast<const caffe::NetParameter*>(model_defs[0]);
56+
const caffe::NetParameter* net_param_weight =
57+
reinterpret_cast<const caffe::NetParameter*>(model_defs[0]);
58+
return Parse(bb, *net_param, *net_param_weight, opts);
59+
}
60+
4061
Status CAFFEParser::Parse(Function* function,
4162
const std::vector<std::string>& file_list,
4263
const armory::Opts& opts) {

lib/parser/caffe/caffe_parser.h

+5
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ class CAFFEParser : public Parser {
5858
Status Parse(BasicBlock* bb, const caffe::NetParameter& layer_param,
5959
const caffe::NetParameter& layer_param_weight,
6060
const armory::Opts& opts);
61+
Status Parse(Function* function, const std::vector<const char*>& buffers,
62+
const std::vector<size_t>& buffer_sizes) override;
63+
Status Parse(Function* function,
64+
const std::vector<const void*>& model_defs) override;
65+
6166
~CAFFEParser();
6267

6368
static void WriteCSVReport(const caffe::LayerParameter& layer_param,

lib/parser/onnx/onnx_parser.cc

+21
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "halo/lib/framework/common.h"
2626
#include "halo/lib/framework/type.h"
2727
#include "halo/lib/ir/extension_instructions.h"
28+
#include "halo/lib/parser/parser.h"
2829
#include "onnx.pb.h"
2930

3031
namespace halo {
@@ -56,6 +57,26 @@ void ONNXParser::Scope::Insert(const std::string& name, const Value& def) {
5657
inst_name_to_ptr_[name] = def;
5758
}
5859

60+
Status ONNXParser::Parse(Function* function,
61+
const std::vector<const char*>& buffers,
62+
const std::vector<size_t>& buffer_sizes) {
63+
return Status::ASSERTION;
64+
}
65+
66+
Status ONNXParser::Parse(Function* function,
67+
const std::vector<const void*>& model_defs) {
68+
if (model_defs.empty() || model_defs[0] == nullptr) {
69+
return Status::FILE_NOT_EXIST;
70+
}
71+
72+
const onnx::GraphProto* graph_def =
73+
reinterpret_cast<const onnx::GraphProto*>(model_defs[0]);
74+
bb_builder_ = std::make_unique<BasicBlockBuilder>(function);
75+
BasicBlock* bb = bb_builder_->CreateBasicBlock("bb0");
76+
armory::Opts opts;
77+
return Parse(bb, *graph_def, opts);
78+
}
79+
5980
Status ONNXParser::Parse(Function* function,
6081
const std::vector<std::string>& file_list,
6182
const armory::Opts& opts) {

lib/parser/onnx/onnx_parser.h

+5
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,13 @@ class ONNXParser : public Parser {
7777
explicit ONNXParser(){};
7878
Status Parse(Function* function, const std::vector<std::string>& file_list,
7979
const armory::Opts& opts) override;
80+
Status Parse(Function* function, const std::vector<const char*>& buffers,
81+
const std::vector<size_t>& buffer_sizes) override;
82+
Status Parse(Function* function,
83+
const std::vector<const void*>& model_defs) override;
8084
Status Parse(BasicBlock* bb, const onnx::GraphProto& graph_def,
8185
const armory::Opts& opts);
86+
8287
~ONNXParser();
8388

8489
template <typename T>

lib/parser/parser.cc

+44-12
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,48 @@ static bool ValidateFiles(const std::vector<std::string>& file_list) {
3535
return true;
3636
}
3737

38-
Status Parser::Parse(Function* function, Format format,
39-
const std::string& variant,
40-
const std::vector<std::string>& file_list,
41-
const armory::Opts& opts) {
42-
if (!ValidateFiles(file_list)) {
43-
return Status::FILE_NOT_EXIST;
44-
}
45-
38+
static std::unique_ptr<Parser> GetParser(Parser::Format format,
39+
const std::string& variant,
40+
const armory::Opts& opts) {
4641
std::unique_ptr<Parser> parser(nullptr);
4742
switch (format) {
48-
case Format::TENSORFLOW: {
43+
case Parser::Format::TENSORFLOW: {
4944
if (opts.convert_to_ipu_graphdef) {
5045
parser = CreateIPUParser(variant);
5146
} else {
5247
parser = CreateTFParser(variant);
5348
}
5449
break;
5550
}
56-
case Format::ONNX: {
51+
case Parser::Format::ONNX: {
5752
parser = CreateONNXParser();
5853
break;
5954
}
60-
case Format::TFLITE: {
55+
case Parser::Format::TFLITE: {
6156
parser = CreateTFLITEParser();
6257
break;
6358
}
64-
case Format::CAFFE: {
59+
case Parser::Format::CAFFE: {
6560
parser = CreateCAFFEParser();
6661
break;
6762
}
6863
default:
6964
HLCHECK(0 && "Unsupported format");
7065
}
66+
return parser;
67+
}
68+
69+
Status Parser::Parse(Function* function, Format format,
70+
const std::string& variant,
71+
const std::vector<std::string>& file_list,
72+
const armory::Opts& opts) {
73+
if (!ValidateFiles(file_list)) {
74+
return Status::FILE_NOT_EXIST;
75+
}
76+
auto parser = GetParser(format, variant, opts);
77+
if (parser == nullptr) {
78+
return Status::ILLEGAL_PARAM;
79+
}
7180
return parser->Parse(function, file_list, opts);
7281
}
7382

@@ -77,4 +86,27 @@ Status Parser::Parse(Function* function, Format format,
7786
return Parse(function, format, "", file_list, opts);
7887
}
7988

89+
Status Parser::Parse(Function* function,
90+
const std::vector<const char*>& buffers,
91+
const std::vector<size_t>& buffer_sizes, Format format) {
92+
armory::Opts opts;
93+
std::string variant;
94+
auto parser = GetParser(format, variant, opts);
95+
if (parser == nullptr) {
96+
return Status::ILLEGAL_PARAM;
97+
}
98+
return parser->Parse(function, buffers, buffer_sizes);
99+
}
100+
101+
Status Parser::Parse(Function* function, const std::vector<const void*>& model,
102+
Format format) {
103+
armory::Opts opts;
104+
std::string variant;
105+
auto parser = GetParser(format, variant, opts);
106+
if (parser == nullptr) {
107+
return Status::ILLEGAL_PARAM;
108+
}
109+
return parser->Parse(function, model);
110+
}
111+
80112
} // namespace halo

lib/parser/tensorflow/tf_parser.cc

+37
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
#include "tf_parser.h"
1919

20+
#include <google/protobuf/io/coded_stream.h>
2021
#include <google/protobuf/io/zero_copy_stream_impl.h>
22+
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
2123
#include <google/protobuf/text_format.h>
2224

2325
#include <fstream>
@@ -61,6 +63,41 @@ Status TFParser::Parse(Function* function,
6163
return Parse(bb, graph_def, opts);
6264
}
6365

66+
Status TFParser::Parse(Function* function,
67+
const std::vector<const char*>& buffers,
68+
const std::vector<size_t>& buffer_sizes) {
69+
GOOGLE_PROTOBUF_VERIFY_VERSION;
70+
71+
tensorflow::GraphDef graph_def;
72+
73+
google::protobuf::io::ArrayInputStream ais(buffers[0], buffer_sizes[0]);
74+
75+
if (!graph_def.ParseFromZeroCopyStream(&ais)) {
76+
if (!google::protobuf::TextFormat::Parse(&ais, &graph_def)) {
77+
LOG(ERROR) << "Encountered error(s) when parsing memory graph";
78+
return Status::ASSERTION;
79+
}
80+
}
81+
82+
BasicBlockBuilder bb_builder(function);
83+
BasicBlock* bb = bb_builder.CreateBasicBlock("bb0");
84+
return Parse(bb, graph_def, opts_);
85+
}
86+
87+
Status TFParser::Parse(Function* function,
88+
const std::vector<const void*>& model_defs) {
89+
if (model_defs.empty() || model_defs[0] == nullptr) {
90+
return Status::FILE_NOT_EXIST;
91+
}
92+
BasicBlockBuilder bb_builder(function);
93+
BasicBlock* bb = bb_builder.CreateBasicBlock("bb0");
94+
95+
const tensorflow::GraphDef* graph_def =
96+
reinterpret_cast<const tensorflow::GraphDef*>(model_defs[0]);
97+
98+
return Parse(bb, *graph_def, opts_);
99+
}
100+
64101
Status TFParser::Parse(BasicBlock* bb, const tensorflow::GraphDef& graph_def,
65102
const armory::Opts& opts) {
66103
Init(bb, bb->GetParent(), opts);

lib/parser/tensorflow/tf_parser.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ class TFParser : public Parser {
6060
explicit TFParser(const std::string& variant) : variant_(variant) {}
6161
Status Parse(Function* function, const std::vector<std::string>& file_list,
6262
const armory::Opts& opts) override;
63-
virtual Status Parse(BasicBlock* bb, const tensorflow::GraphDef& graph_def,
64-
const armory::Opts& opts);
63+
Status Parse(Function* function, const std::vector<const char*>& buffers,
64+
const std::vector<size_t>& buffer_sizes) override;
65+
Status Parse(Function* function,
66+
const std::vector<const void*>& model_defs) override;
67+
6568
~TFParser();
6669

6770
static std::vector<int64_t> ProcessShape(
@@ -73,6 +76,10 @@ class TFParser : public Parser {
7376
TFParser(const TFParser&) = delete;
7477
TFParser& operator=(const TFParser&) = delete;
7578

79+
protected:
80+
virtual Status Parse(BasicBlock* bb, const tensorflow::GraphDef& graph_def,
81+
const armory::Opts& opts);
82+
7683
private:
7784
void Init(BasicBlock* bb, Function* function, const armory::Opts& opts);
7885
void RegisterOp();

lib/parser/tflite/tflite_parser.cc

+19
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,25 @@ Status TFLITEParser::Parse(Function* function,
197197
return Parse(bb, model);
198198
}
199199

200+
Status TFLITEParser::Parse(Function* function,
201+
const std::vector<const void*>& model_defs) {
202+
if (model_defs.empty() || model_defs[0] == nullptr) {
203+
return Status::FILE_NOT_EXIST;
204+
}
205+
206+
const tflite::Model* graph_def =
207+
reinterpret_cast<const tflite::Model*>(model_defs[0]);
208+
auto bb_builder = std::make_unique<BasicBlockBuilder>(function);
209+
BasicBlock* bb = bb_builder->CreateBasicBlock("bb0");
210+
return Parse(bb, *graph_def);
211+
}
212+
213+
Status TFLITEParser::Parse(Function* function,
214+
const std::vector<const char*>& buffers,
215+
const std::vector<size_t>& buffer_sizes) {
216+
return Status::ASSERTION;
217+
}
218+
200219
Status TFLITEParser::Parse(BasicBlock* bb, const tflite::Model& model) {
201220
RegisterOp();
202221
auto function = bb->GetParent();

lib/parser/tflite/tflite_parser.h

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class TFLITEParser : public Parser {
3838
Status Parse(Function* function, const std::vector<std::string>& file_list,
3939
const armory::Opts& opts) override;
4040
Status Parse(BasicBlock* bb, const tflite::Model& model);
41+
Status Parse(Function* function, const std::vector<const char*>& buffers,
42+
const std::vector<size_t>& buffer_sizes) override;
43+
Status Parse(Function* function,
44+
const std::vector<const void*>& model_defs) override;
4145

4246
template <typename T>
4347
static const Tensor<T> ProcessTensor(

0 commit comments

Comments
 (0)