Skip to content

Commit 43964fc

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[Refactor] Decouple concrete passes from driver
1 parent a74a554 commit 43964fc

23 files changed

+545
-316
lines changed

armory/analyzer/driver.cc

+7-25
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,8 @@
2323
#include "halo/lib/ir/ir_builder.h"
2424
#include "halo/lib/parser/parser.h"
2525
#include "halo/lib/pass/pass_manager.h"
26-
#include "halo/lib/transforms/analyzer.h"
27-
#include "halo/lib/transforms/caffeextension_legalizer.h"
28-
#include "halo/lib/transforms/dce.h"
29-
#include "halo/lib/transforms/input_legalizer.h"
30-
#include "halo/lib/transforms/inst_simplify.h"
31-
#include "halo/lib/transforms/onnxextension_legalizer.h"
32-
#include "halo/lib/transforms/tfextension_legalizer.h"
33-
#include "halo/lib/transforms/type_legalizer.h"
3426
#include "halo/utils/cl_options.h"
27+
#include "halo/utils/passes_helper.h"
3528
#include "llvm/ADT/SmallVector.h"
3629
#include "llvm/Support/FileSystem.h"
3730
#include "llvm/Support/Path.h"
@@ -60,24 +53,13 @@ static void PopulatePassesAndRun(GlobalContext& ctx, Module& m,
6053
Parser::Format format) {
6154
PassManager pm(ctx);
6255
std::vector<std::string> input_shapes(InputsShape.begin(), InputsShape.end());
63-
pm.AddPass<InputLegalizer>(batch.getValue(), input_shapes,
64-
PreprocessScale.getValue());
65-
if (format == Parser::Format::CAFFE) {
66-
pm.AddPass<CAFFEExtensionLegalizer>();
67-
} else if (format == Parser::Format::TENSORFLOW) {
68-
pm.AddPass<TFExtensionLegalizer>();
69-
} else {
70-
HLCHECK(format == Parser::Format::ONNX);
71-
pm.AddPass<ONNXExtensionLegalizer>();
72-
}
73-
pm.AddPass<DCE>();
74-
pm.AddPass<TypeLegalizer>(true);
75-
pm.AddPass<InstSimplify>(true, true, false, false, false, false);
76-
auto analyzer = pm.AddPass<Analyzer>();
56+
Fusion::Options fusion_opts;
57+
Opts opts;
58+
PopulateOptPasses(&pm, "cxx", input_shapes, {}, {}, batch, "",
59+
ReorderChannel::ChannelOrder::None, false, false, format,
60+
opts, fusion_opts);
61+
pm.AddAnalyzerPass(&std::cout);
7762
pm.Run(&m);
78-
if (PrintAnalysisReport) {
79-
analyzer->WriteCSVReport(std::cout);
80-
}
8163
}
8264

8365
int main(int argc, char** argv) {

driver/driver.cc

+41-193
Original file line numberDiff line numberDiff line change
@@ -23,29 +23,10 @@
2323
#include "halo/lib/ir/ir_builder.h"
2424
#include "halo/lib/parser/parser.h"
2525
#include "halo/lib/pass/pass_manager.h"
26-
#include "halo/lib/quantizer/weights_quantizer.h"
27-
#include "halo/lib/target/cpu/arm/binary/arm_llvmir_codegen.h"
28-
#include "halo/lib/target/cpu/riscv/binary/riscv_llvmir_codegen.h"
29-
#include "halo/lib/target/cpu/x86/binary/x86_llvmir_codegen.h"
30-
#include "halo/lib/target/generic_cxx/generic_cxx_codegen.h"
31-
#include "halo/lib/target/generic_llvmir/generic_llvmir_codegen.h"
32-
#include "halo/lib/target/triton/triton_config_writer.h"
33-
#include "halo/lib/transforms/caffeextension_legalizer.h"
34-
#include "halo/lib/transforms/dce.h"
35-
#include "halo/lib/transforms/device_placement.h"
3626
#include "halo/lib/transforms/fusion.h"
37-
#include "halo/lib/transforms/input_legalizer.h"
38-
#include "halo/lib/transforms/input_rewriter.h"
39-
#include "halo/lib/transforms/inst_simplify.h"
40-
#include "halo/lib/transforms/onnxextension_legalizer.h"
41-
#include "halo/lib/transforms/output_rewriter.h"
4227
#include "halo/lib/transforms/reorder_channel.h"
43-
#include "halo/lib/transforms/splitting.h"
44-
#include "halo/lib/transforms/tfextension_legalizer.h"
45-
#include "halo/lib/transforms/tfliteextension_legalizer.h"
46-
#include "halo/lib/transforms/type_legalizer.h"
47-
#include "halo/lib/transforms/typecast.h"
4828
#include "halo/utils/cl_options.h"
29+
#include "halo/utils/passes_helper.h"
4930
#include "halo/version.h"
5031
#include "llvm/ADT/SmallVector.h"
5132
#include "llvm/ADT/StringSwitch.h"
@@ -248,177 +229,6 @@ static llvm::cl::opt<bool> CheckModel("check-model",
248229
#include "halo/lib/ir/fusion.cc.inc"
249230
#undef HALO_FUSION_CMD_OPTIONS_DECL
250231

251-
static void PopulateCodeGenPasses(PassManager* pm, std::ostream* out_code,
252-
std::ostream* out_constants,
253-
std::ostream* out_header,
254-
std::ostream* out_dynamic_check,
255-
bool is_c_or_cxx_output,
256-
bool is_binary_output) {
257-
auto constant_storage =
258-
GenericLLVMIRCodeGen::ConstantDataStorage::DefinedAsStatic;
259-
if (SeparateConstants) {
260-
constant_storage =
261-
GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal;
262-
}
263-
264-
CodeGen* cg = nullptr;
265-
if (is_c_or_cxx_output) {
266-
Opts opts(BF16Mode);
267-
if (llvm::StringRef(Target).startswith_lower("cc")) {
268-
opts.dialect = Dialect::C99;
269-
}
270-
opts.print_mem_stats = PrintMemStats;
271-
opts.emit_value_reset = EmitValueReset;
272-
opts.exec_mode = ExecMode.getValue();
273-
opts.emit_value_id_as_int = EmitValueIDAsInt;
274-
opts.emit_inference_func_sig = EmitInferenceFunctionSignature;
275-
opts.emit_dynamic_batch = (Batch.getValue() == kDynamicBatchSize);
276-
opts.fp16_mode = EnableFP16;
277-
opts.max_batch_size = MaxBatch.getValue();
278-
opts.min_batch_size = MinBatch.getValue();
279-
opts.opt_batch_size = OptBatch.getValue();
280-
opts.check_model = CheckModel;
281-
opts.enable_ipu_device = EnableIpuDevice;
282-
opts.use_ipu_model = UseIpuModel;
283-
opts.ipu_num = IpuNum;
284-
opts.batches_per_step = BatchesPerStep;
285-
286-
pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), PGQFile.getValue());
287-
cg = pm->AddPass<GenericCXXCodeGen>(std::ref(*out_code),
288-
std::ref(*out_header),
289-
std::ref(*out_dynamic_check), opts);
290-
cg->SetAPI(Api);
291-
292-
if (EmitDataAsC) {
293-
pm->AddPass<GenericCXXConstantWriter>(std::ref(*out_constants));
294-
} else {
295-
pm->AddPass<X86ConstantWriter>(std::ref(*out_constants));
296-
}
297-
if (EmitTritonConfig) {
298-
pm->AddPass<TritonConfigWriter>(
299-
TritonConfigFile.getValue(),
300-
opts.emit_dynamic_batch ? MaxBatch.getValue() : 0);
301-
}
302-
return;
303-
}
304-
305-
if (EmitLLVMIR) {
306-
pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), PGQFile.getValue());
307-
cg = pm->AddPass<GenericLLVMIRCodeGen>(constant_storage);
308-
pm->AddPass<GenericLLVMIRWriter>(std::ref(*out_code), is_binary_output);
309-
if (SeparateConstants && !EmitCodeOnly) {
310-
pm->AddPass<GenericConstantWriter>(std::ref(*out_constants),
311-
is_binary_output);
312-
}
313-
} else {
314-
llvm::Triple triple(Target);
315-
switch (triple.getArch()) {
316-
case llvm::Triple::ArchType::x86:
317-
case llvm::Triple::ArchType::x86_64: {
318-
pm->AddPass<X86LLVMIRCodeGen>(
319-
GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal);
320-
pm->AddPass<X86BinaryWriter>(std::ref(*out_code));
321-
if (SeparateConstants && !EmitCodeOnly) {
322-
pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(),
323-
PGQFile.getValue());
324-
pm->AddPass<X86ConstantWriter>(std::ref(*out_constants));
325-
}
326-
break;
327-
}
328-
case llvm::Triple::ArchType::aarch64: {
329-
pm->AddPass<ARMLLVMIRCodeGen>(
330-
GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal);
331-
pm->AddPass<ARMBinaryWriter>(std::ref(*out_code));
332-
if (SeparateConstants && !EmitCodeOnly) {
333-
pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(),
334-
PGQFile.getValue());
335-
pm->AddPass<ARMConstantWriter>(std::ref(*out_constants));
336-
}
337-
break;
338-
}
339-
case llvm::Triple::ArchType::riscv32:
340-
case llvm::Triple::ArchType::riscv64: {
341-
if (RISCVOpt) {
342-
pm->AddPass<RISCVLLVMIRCodeGen>(
343-
GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal,
344-
"libRT_RISCV.a");
345-
} else {
346-
pm->AddPass<RISCVLLVMIRCodeGen>(
347-
GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal);
348-
}
349-
pm->AddPass<RISCVBinaryWriter>(std::ref(*out_code));
350-
if (SeparateConstants && !EmitCodeOnly) {
351-
pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(),
352-
PGQFile.getValue());
353-
pm->AddPass<RISCVConstantWriter>(std::ref(*out_constants));
354-
}
355-
356-
break;
357-
}
358-
359-
default: {
360-
HLCHECK(0 && "Unsupported");
361-
}
362-
}
363-
}
364-
if (cg != nullptr) {
365-
cg->SetAPI(Api);
366-
}
367-
}
368-
369-
static void PopulatePasses(PassManager* pm, std::ostream* out_code,
370-
std::ostream* out_constants,
371-
std::ostream* out_header,
372-
std::ostream* out_dynamic_check,
373-
bool is_c_or_cxx_output, bool is_binary_output,
374-
Parser::Format format) {
375-
std::vector<std::string> input_shapes(InputsShape.begin(), InputsShape.end());
376-
pm->AddPass<InputLegalizer>(Batch.getValue(), input_shapes,
377-
PreprocessScale.getValue());
378-
if (!Outputs.empty()) {
379-
std::vector<std::string> outputs(Outputs.begin(), Outputs.end());
380-
pm->AddPass<OutputRewriter>(outputs);
381-
}
382-
if (format == Parser::Format::CAFFE) {
383-
pm->AddPass<CAFFEExtensionLegalizer>();
384-
} else if (format == Parser::Format::TENSORFLOW) {
385-
pm->AddPass<TFExtensionLegalizer>();
386-
} else if (format == Parser::Format::TFLITE) {
387-
HLCHECK(format == Parser::Format::TFLITE);
388-
pm->AddPass<TFLITEExtensionLegalizer>();
389-
} else {
390-
HLCHECK(format == Parser::Format::ONNX);
391-
pm->AddPass<ONNXExtensionLegalizer>();
392-
}
393-
pm->AddPass<DCE>();
394-
pm->AddPass<TypeLegalizer>(true);
395-
if (!Inputs.empty()) {
396-
std::vector<std::string> inputs(Inputs.begin(), Inputs.end());
397-
pm->AddPass<InputRewriter>(inputs);
398-
}
399-
auto fusion_opts = GetFusionOptions();
400-
pm->AddPass<InstSimplify>(
401-
llvm::StringRef(Target).startswith("cxx"), DisableBroadcasting.getValue(),
402-
RemoveInputTranspose.getValue(), RemoveOutputTranspose.getValue(),
403-
DisableConvBN.getValue(), fusion_opts.ConvBias);
404-
if (ReorderChannelLayout != ReorderChannel::ChannelOrder::None) {
405-
pm->AddPass<ReorderChannel>(ReorderChannelLayout ==
406-
ReorderChannel::ChannelOrder::ChannelFirst);
407-
}
408-
pm->AddPass<Fusion>(fusion_opts);
409-
if (SplitFunction) {
410-
pm->AddPass<Splitting>();
411-
pm->AddPass<DevicePlacement>();
412-
}
413-
if (!DisableTypeCast) {
414-
pm->AddPass<TypeCast>();
415-
}
416-
417-
PopulateCodeGenPasses(pm, out_code, out_constants, out_header,
418-
out_dynamic_check, is_c_or_cxx_output,
419-
is_binary_output);
420-
}
421-
422232
static bool FormatCode(const std::string& filename) {
423233
if (filename.empty() || filename == "-") {
424234
return false;
@@ -540,11 +350,49 @@ int main(int argc, char** argv) {
540350
out_dynamic_check = &of_dynamic_check;
541351
}
542352

543-
PopulatePasses(&pm, out_code, out_constants, out_header, out_dynamic_check,
544-
is_c_or_cxx_output, is_binary_output, format);
353+
Opts cg_opts;
354+
cg_opts.bf16_mode = BF16Mode;
355+
cg_opts.print_mem_stats = PrintMemStats;
356+
cg_opts.emit_value_reset = EmitValueReset;
357+
cg_opts.exec_mode = ExecMode.getValue();
358+
cg_opts.emit_value_id_as_int = EmitValueIDAsInt;
359+
cg_opts.emit_inference_func_sig = EmitInferenceFunctionSignature;
360+
cg_opts.emit_dynamic_batch = (Batch.getValue() == kDynamicBatchSize);
361+
cg_opts.fp16_mode = EnableFP16;
362+
cg_opts.max_batch_size = MaxBatch.getValue();
363+
cg_opts.min_batch_size = MinBatch.getValue();
364+
cg_opts.opt_batch_size = OptBatch.getValue();
365+
cg_opts.check_model = CheckModel;
366+
cg_opts.enable_ipu_device = EnableIpuDevice;
367+
cg_opts.use_ipu_model = UseIpuModel;
368+
cg_opts.ipu_num = IpuNum;
369+
cg_opts.batches_per_step = BatchesPerStep;
370+
cg_opts.api = Api;
371+
cg_opts.disable_broadcasting = DisableBroadcasting;
372+
cg_opts.separate_constants = SeparateConstants;
373+
cg_opts.disable_conv_bn = DisableConvBN;
374+
cg_opts.remove_input_transpose = RemoveInputTranspose;
375+
cg_opts.remove_output_transpose = RemoveOutputTranspose;
376+
545377
if (is_c_or_cxx_output) {
546378
ctx.SetTargetTriple("x86_64"); // For binary constant writer.
379+
if (llvm::StringRef(Target).startswith_lower("cc")) {
380+
cg_opts.dialect = Dialect::C99;
381+
}
547382
}
383+
std::vector<std::string> input_shapes(InputsShape.begin(), InputsShape.end());
384+
std::vector<std::string> inputs(Inputs.begin(), Inputs.end());
385+
std::vector<std::string> outputs(Outputs.begin(), Outputs.end());
386+
const auto& fusion_opts = GetFusionOptions();
387+
388+
PopulateOptPasses(&pm, Target, input_shapes, inputs, outputs, Batch,
389+
PreprocessScale, ReorderChannelLayout, SplitFunction,
390+
DisableTypeCast, format, cg_opts, fusion_opts);
391+
PopulateCodeGenPasses(&pm, out_code, out_constants, out_header,
392+
out_dynamic_check, Target, is_c_or_cxx_output,
393+
is_binary_output, EmitDataAsC, EmitCodeOnly, EmitLLVMIR,
394+
EmitTritonConfig, TritonConfigFile, QuantWeights,
395+
PGQFile, RISCVOpt, cg_opts);
548396

549397
auto status = pm.Run(&m);
550398

include/halo/lib/pass/pass_manager.h

+62-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#include "halo/lib/framework/global_context.h"
2626
#include "halo/lib/ir/module.h"
2727
#include "halo/lib/pass/pass.h"
28+
#include "halo/lib/target/generic_cxx/generic_cxx_codegen.h"
29+
#include "halo/lib/target/generic_llvmir/generic_llvmir_codegen.h"
30+
#include "halo/lib/transforms/fusion.h"
2831

2932
namespace halo {
3033

@@ -39,7 +42,7 @@ class PassManager final {
3942

4043
/// Add a pass to the pass manager.
4144
template <typename T, typename... TS>
42-
T* AddPass(TS... args) {
45+
T* AddPass(TS&... args) {
4346
auto pass = std::make_unique<T>(args...);
4447
T* ret = static_cast<T*>(pass.get());
4548
Add(std::move(pass));
@@ -53,6 +56,64 @@ class PassManager final {
5356
void Print(std::ostream& os) const;
5457

5558
void Dump() const;
59+
Pass* AddAnalyzerPass(std::ostream* os);
60+
Pass* AddARMBinaryWriterPass(std::ostream& os);
61+
Pass* AddARMConstantWriterPass(std::ostream& os);
62+
Pass* AddARMLLVMIRCodeGenPass(
63+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage);
64+
Pass* AddCAFFEExtensionLegalizerPass();
65+
Pass* AddDCEPass();
66+
Pass* AddDevicePlacementPass();
67+
Pass* AddFusionPass(const Fusion::Options& opts);
68+
Pass* AddGenericConstantWriterPass(std::ostream& os, bool bitcode_format);
69+
Pass* AddGenericCXXConstantWriterPass(std::ostream& os);
70+
Pass* AddGenericCXXCodeGenPass(std::ostream& os, std::ostream& header_os);
71+
Pass* AddGenericCXXCodeGenPass(std::ostream& os, std::ostream& header_os,
72+
std::ostream& dynamic_check_os,
73+
const Opts& opts);
74+
Pass* AddGenericLLVMIRCodeGenPass();
75+
Pass* AddGenericLLVMIRCodeGenPass(
76+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage);
77+
Pass* AddGenericLLVMIRCodeGenPass(
78+
const std::string& name,
79+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage);
80+
Pass* AddGenericLLVMIRWriterPass(std::ostream& os, bool bitcode_format);
81+
Pass* AddInputLegalizerPass(int batch_size,
82+
const std::vector<std::string>& inputs_shapes,
83+
const std::string& scale_str);
84+
Pass* AddInputRewriterPass(const std::vector<std::string>& inputs);
85+
Pass* AddInstSimplifyPass();
86+
Pass* AddInstSimplifyPass(bool simplify_for_preprocess,
87+
bool disable_broadcasting,
88+
bool remove_input_transpose,
89+
bool remove_output_transpose, bool disable_conv_bn,
90+
bool fuse_conv_bias);
91+
Pass* AddONNXExtensionLegalizerPass();
92+
Pass* AddOutputRewriterPass(const std::vector<std::string>& outputs);
93+
Pass* AddReorderChannelPass(bool channel_first);
94+
Pass* AddRISCVBinaryWriterPass(std::ostream& os);
95+
Pass* AddRISCVConstantWriterPass(std::ostream& os);
96+
Pass* AddRISCVLLVMIRCodeGenPass(
97+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage);
98+
99+
Pass* AddRISCVLLVMIRCodeGenPass(
100+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage,
101+
std::string rt_lib_name);
102+
Pass* AddSplittingPass();
103+
Pass* AddTFExtensionLegalizerPass();
104+
Pass* AddTFLiteExtensionLegalizerPass();
105+
Pass* AddTritonConfigWriterPass(const std::string& filename,
106+
int max_batch_size);
107+
Pass* AddTypeCastPass();
108+
Pass* AddTypeLegalizerPass();
109+
Pass* AddTypeLegalizerPass(bool relaxed);
110+
Pass* AddWeightsQuantizerPass(CodeGen::Quantization quant,
111+
const std::string& file);
112+
Pass* AddX86BinaryWriterPass(std::ostream& os);
113+
Pass* AddX86ConstantWriterPass(std::ostream& os);
114+
Pass* AddX86LLVMIRCodeGenPass();
115+
Pass* AddX86LLVMIRCodeGenPass(
116+
GenericLLVMIRCodeGen::ConstantDataStorage constant_data_storage);
56117

57118
private:
58119
Pass* Add(std::unique_ptr<ModulePass> pass);

0 commit comments

Comments
 (0)