|
23 | 23 | #include "halo/lib/ir/ir_builder.h"
|
24 | 24 | #include "halo/lib/parser/parser.h"
|
25 | 25 | #include "halo/lib/pass/pass_manager.h"
|
26 |
| -#include "halo/lib/quantizer/weights_quantizer.h" |
27 |
| -#include "halo/lib/target/cpu/arm/binary/arm_llvmir_codegen.h" |
28 |
| -#include "halo/lib/target/cpu/riscv/binary/riscv_llvmir_codegen.h" |
29 |
| -#include "halo/lib/target/cpu/x86/binary/x86_llvmir_codegen.h" |
30 |
| -#include "halo/lib/target/generic_cxx/generic_cxx_codegen.h" |
31 |
| -#include "halo/lib/target/generic_llvmir/generic_llvmir_codegen.h" |
32 |
| -#include "halo/lib/target/triton/triton_config_writer.h" |
33 |
| -#include "halo/lib/transforms/caffeextension_legalizer.h" |
34 |
| -#include "halo/lib/transforms/dce.h" |
35 |
| -#include "halo/lib/transforms/device_placement.h" |
36 | 26 | #include "halo/lib/transforms/fusion.h"
|
37 |
| -#include "halo/lib/transforms/input_legalizer.h" |
38 |
| -#include "halo/lib/transforms/input_rewriter.h" |
39 |
| -#include "halo/lib/transforms/inst_simplify.h" |
40 |
| -#include "halo/lib/transforms/onnxextension_legalizer.h" |
41 |
| -#include "halo/lib/transforms/output_rewriter.h" |
42 | 27 | #include "halo/lib/transforms/reorder_channel.h"
|
43 |
| -#include "halo/lib/transforms/splitting.h" |
44 |
| -#include "halo/lib/transforms/tfextension_legalizer.h" |
45 |
| -#include "halo/lib/transforms/tfliteextension_legalizer.h" |
46 |
| -#include "halo/lib/transforms/type_legalizer.h" |
47 |
| -#include "halo/lib/transforms/typecast.h" |
48 | 28 | #include "halo/utils/cl_options.h"
|
| 29 | +#include "halo/utils/passes_helper.h" |
49 | 30 | #include "halo/version.h"
|
50 | 31 | #include "llvm/ADT/SmallVector.h"
|
51 | 32 | #include "llvm/ADT/StringSwitch.h"
|
@@ -248,177 +229,6 @@ static llvm::cl::opt<bool> CheckModel("check-model",
|
248 | 229 | #include "halo/lib/ir/fusion.cc.inc"
|
249 | 230 | #undef HALO_FUSION_CMD_OPTIONS_DECL
|
250 | 231 |
|
251 |
| -static void PopulateCodeGenPasses(PassManager* pm, std::ostream* out_code, |
252 |
| - std::ostream* out_constants, |
253 |
| - std::ostream* out_header, |
254 |
| - std::ostream* out_dynamic_check, |
255 |
| - bool is_c_or_cxx_output, |
256 |
| - bool is_binary_output) { |
257 |
| - auto constant_storage = |
258 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DefinedAsStatic; |
259 |
| - if (SeparateConstants) { |
260 |
| - constant_storage = |
261 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal; |
262 |
| - } |
263 |
| - |
264 |
| - CodeGen* cg = nullptr; |
265 |
| - if (is_c_or_cxx_output) { |
266 |
| - Opts opts(BF16Mode); |
267 |
| - if (llvm::StringRef(Target).startswith_lower("cc")) { |
268 |
| - opts.dialect = Dialect::C99; |
269 |
| - } |
270 |
| - opts.print_mem_stats = PrintMemStats; |
271 |
| - opts.emit_value_reset = EmitValueReset; |
272 |
| - opts.exec_mode = ExecMode.getValue(); |
273 |
| - opts.emit_value_id_as_int = EmitValueIDAsInt; |
274 |
| - opts.emit_inference_func_sig = EmitInferenceFunctionSignature; |
275 |
| - opts.emit_dynamic_batch = (Batch.getValue() == kDynamicBatchSize); |
276 |
| - opts.fp16_mode = EnableFP16; |
277 |
| - opts.max_batch_size = MaxBatch.getValue(); |
278 |
| - opts.min_batch_size = MinBatch.getValue(); |
279 |
| - opts.opt_batch_size = OptBatch.getValue(); |
280 |
| - opts.check_model = CheckModel; |
281 |
| - opts.enable_ipu_device = EnableIpuDevice; |
282 |
| - opts.use_ipu_model = UseIpuModel; |
283 |
| - opts.ipu_num = IpuNum; |
284 |
| - opts.batches_per_step = BatchesPerStep; |
285 |
| - |
286 |
| - pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), PGQFile.getValue()); |
287 |
| - cg = pm->AddPass<GenericCXXCodeGen>(std::ref(*out_code), |
288 |
| - std::ref(*out_header), |
289 |
| - std::ref(*out_dynamic_check), opts); |
290 |
| - cg->SetAPI(Api); |
291 |
| - |
292 |
| - if (EmitDataAsC) { |
293 |
| - pm->AddPass<GenericCXXConstantWriter>(std::ref(*out_constants)); |
294 |
| - } else { |
295 |
| - pm->AddPass<X86ConstantWriter>(std::ref(*out_constants)); |
296 |
| - } |
297 |
| - if (EmitTritonConfig) { |
298 |
| - pm->AddPass<TritonConfigWriter>( |
299 |
| - TritonConfigFile.getValue(), |
300 |
| - opts.emit_dynamic_batch ? MaxBatch.getValue() : 0); |
301 |
| - } |
302 |
| - return; |
303 |
| - } |
304 |
| - |
305 |
| - if (EmitLLVMIR) { |
306 |
| - pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), PGQFile.getValue()); |
307 |
| - cg = pm->AddPass<GenericLLVMIRCodeGen>(constant_storage); |
308 |
| - pm->AddPass<GenericLLVMIRWriter>(std::ref(*out_code), is_binary_output); |
309 |
| - if (SeparateConstants && !EmitCodeOnly) { |
310 |
| - pm->AddPass<GenericConstantWriter>(std::ref(*out_constants), |
311 |
| - is_binary_output); |
312 |
| - } |
313 |
| - } else { |
314 |
| - llvm::Triple triple(Target); |
315 |
| - switch (triple.getArch()) { |
316 |
| - case llvm::Triple::ArchType::x86: |
317 |
| - case llvm::Triple::ArchType::x86_64: { |
318 |
| - pm->AddPass<X86LLVMIRCodeGen>( |
319 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal); |
320 |
| - pm->AddPass<X86BinaryWriter>(std::ref(*out_code)); |
321 |
| - if (SeparateConstants && !EmitCodeOnly) { |
322 |
| - pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), |
323 |
| - PGQFile.getValue()); |
324 |
| - pm->AddPass<X86ConstantWriter>(std::ref(*out_constants)); |
325 |
| - } |
326 |
| - break; |
327 |
| - } |
328 |
| - case llvm::Triple::ArchType::aarch64: { |
329 |
| - pm->AddPass<ARMLLVMIRCodeGen>( |
330 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal); |
331 |
| - pm->AddPass<ARMBinaryWriter>(std::ref(*out_code)); |
332 |
| - if (SeparateConstants && !EmitCodeOnly) { |
333 |
| - pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), |
334 |
| - PGQFile.getValue()); |
335 |
| - pm->AddPass<ARMConstantWriter>(std::ref(*out_constants)); |
336 |
| - } |
337 |
| - break; |
338 |
| - } |
339 |
| - case llvm::Triple::ArchType::riscv32: |
340 |
| - case llvm::Triple::ArchType::riscv64: { |
341 |
| - if (RISCVOpt) { |
342 |
| - pm->AddPass<RISCVLLVMIRCodeGen>( |
343 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal, |
344 |
| - "libRT_RISCV.a"); |
345 |
| - } else { |
346 |
| - pm->AddPass<RISCVLLVMIRCodeGen>( |
347 |
| - GenericLLVMIRCodeGen::ConstantDataStorage::DeclaredAsExternal); |
348 |
| - } |
349 |
| - pm->AddPass<RISCVBinaryWriter>(std::ref(*out_code)); |
350 |
| - if (SeparateConstants && !EmitCodeOnly) { |
351 |
| - pm->AddPass<WeightsQuantizer>(QuantWeights.getValue(), |
352 |
| - PGQFile.getValue()); |
353 |
| - pm->AddPass<RISCVConstantWriter>(std::ref(*out_constants)); |
354 |
| - } |
355 |
| - |
356 |
| - break; |
357 |
| - } |
358 |
| - |
359 |
| - default: { |
360 |
| - HLCHECK(0 && "Unsupported"); |
361 |
| - } |
362 |
| - } |
363 |
| - } |
364 |
| - if (cg != nullptr) { |
365 |
| - cg->SetAPI(Api); |
366 |
| - } |
367 |
| -} |
368 |
| - |
369 |
| -static void PopulatePasses(PassManager* pm, std::ostream* out_code, |
370 |
| - std::ostream* out_constants, |
371 |
| - std::ostream* out_header, |
372 |
| - std::ostream* out_dynamic_check, |
373 |
| - bool is_c_or_cxx_output, bool is_binary_output, |
374 |
| - Parser::Format format) { |
375 |
| - std::vector<std::string> input_shapes(InputsShape.begin(), InputsShape.end()); |
376 |
| - pm->AddPass<InputLegalizer>(Batch.getValue(), input_shapes, |
377 |
| - PreprocessScale.getValue()); |
378 |
| - if (!Outputs.empty()) { |
379 |
| - std::vector<std::string> outputs(Outputs.begin(), Outputs.end()); |
380 |
| - pm->AddPass<OutputRewriter>(outputs); |
381 |
| - } |
382 |
| - if (format == Parser::Format::CAFFE) { |
383 |
| - pm->AddPass<CAFFEExtensionLegalizer>(); |
384 |
| - } else if (format == Parser::Format::TENSORFLOW) { |
385 |
| - pm->AddPass<TFExtensionLegalizer>(); |
386 |
| - } else if (format == Parser::Format::TFLITE) { |
387 |
| - HLCHECK(format == Parser::Format::TFLITE); |
388 |
| - pm->AddPass<TFLITEExtensionLegalizer>(); |
389 |
| - } else { |
390 |
| - HLCHECK(format == Parser::Format::ONNX); |
391 |
| - pm->AddPass<ONNXExtensionLegalizer>(); |
392 |
| - } |
393 |
| - pm->AddPass<DCE>(); |
394 |
| - pm->AddPass<TypeLegalizer>(true); |
395 |
| - if (!Inputs.empty()) { |
396 |
| - std::vector<std::string> inputs(Inputs.begin(), Inputs.end()); |
397 |
| - pm->AddPass<InputRewriter>(inputs); |
398 |
| - } |
399 |
| - auto fusion_opts = GetFusionOptions(); |
400 |
| - pm->AddPass<InstSimplify>( |
401 |
| - llvm::StringRef(Target).startswith("cxx"), DisableBroadcasting.getValue(), |
402 |
| - RemoveInputTranspose.getValue(), RemoveOutputTranspose.getValue(), |
403 |
| - DisableConvBN.getValue(), fusion_opts.ConvBias); |
404 |
| - if (ReorderChannelLayout != ReorderChannel::ChannelOrder::None) { |
405 |
| - pm->AddPass<ReorderChannel>(ReorderChannelLayout == |
406 |
| - ReorderChannel::ChannelOrder::ChannelFirst); |
407 |
| - } |
408 |
| - pm->AddPass<Fusion>(fusion_opts); |
409 |
| - if (SplitFunction) { |
410 |
| - pm->AddPass<Splitting>(); |
411 |
| - pm->AddPass<DevicePlacement>(); |
412 |
| - } |
413 |
| - if (!DisableTypeCast) { |
414 |
| - pm->AddPass<TypeCast>(); |
415 |
| - } |
416 |
| - |
417 |
| - PopulateCodeGenPasses(pm, out_code, out_constants, out_header, |
418 |
| - out_dynamic_check, is_c_or_cxx_output, |
419 |
| - is_binary_output); |
420 |
| -} |
421 |
| - |
422 | 232 | static bool FormatCode(const std::string& filename) {
|
423 | 233 | if (filename.empty() || filename == "-") {
|
424 | 234 | return false;
|
@@ -540,11 +350,49 @@ int main(int argc, char** argv) {
|
540 | 350 | out_dynamic_check = &of_dynamic_check;
|
541 | 351 | }
|
542 | 352 |
|
543 |
| - PopulatePasses(&pm, out_code, out_constants, out_header, out_dynamic_check, |
544 |
| - is_c_or_cxx_output, is_binary_output, format); |
| 353 | + Opts cg_opts; |
| 354 | + cg_opts.bf16_mode = BF16Mode; |
| 355 | + cg_opts.print_mem_stats = PrintMemStats; |
| 356 | + cg_opts.emit_value_reset = EmitValueReset; |
| 357 | + cg_opts.exec_mode = ExecMode.getValue(); |
| 358 | + cg_opts.emit_value_id_as_int = EmitValueIDAsInt; |
| 359 | + cg_opts.emit_inference_func_sig = EmitInferenceFunctionSignature; |
| 360 | + cg_opts.emit_dynamic_batch = (Batch.getValue() == kDynamicBatchSize); |
| 361 | + cg_opts.fp16_mode = EnableFP16; |
| 362 | + cg_opts.max_batch_size = MaxBatch.getValue(); |
| 363 | + cg_opts.min_batch_size = MinBatch.getValue(); |
| 364 | + cg_opts.opt_batch_size = OptBatch.getValue(); |
| 365 | + cg_opts.check_model = CheckModel; |
| 366 | + cg_opts.enable_ipu_device = EnableIpuDevice; |
| 367 | + cg_opts.use_ipu_model = UseIpuModel; |
| 368 | + cg_opts.ipu_num = IpuNum; |
| 369 | + cg_opts.batches_per_step = BatchesPerStep; |
| 370 | + cg_opts.api = Api; |
| 371 | + cg_opts.disable_broadcasting = DisableBroadcasting; |
| 372 | + cg_opts.separate_constants = SeparateConstants; |
| 373 | + cg_opts.disable_conv_bn = DisableConvBN; |
| 374 | + cg_opts.remove_input_transpose = RemoveInputTranspose; |
| 375 | + cg_opts.remove_output_transpose = RemoveOutputTranspose; |
| 376 | + |
545 | 377 | if (is_c_or_cxx_output) {
|
546 | 378 | ctx.SetTargetTriple("x86_64"); // For binary constant writer.
|
| 379 | + if (llvm::StringRef(Target).startswith_lower("cc")) { |
| 380 | + cg_opts.dialect = Dialect::C99; |
| 381 | + } |
547 | 382 | }
|
| 383 | + std::vector<std::string> input_shapes(InputsShape.begin(), InputsShape.end()); |
| 384 | + std::vector<std::string> inputs(Inputs.begin(), Inputs.end()); |
| 385 | + std::vector<std::string> outputs(Outputs.begin(), Outputs.end()); |
| 386 | + const auto& fusion_opts = GetFusionOptions(); |
| 387 | + |
| 388 | + PopulateOptPasses(&pm, Target, input_shapes, inputs, outputs, Batch, |
| 389 | + PreprocessScale, ReorderChannelLayout, SplitFunction, |
| 390 | + DisableTypeCast, format, cg_opts, fusion_opts); |
| 391 | + PopulateCodeGenPasses(&pm, out_code, out_constants, out_header, |
| 392 | + out_dynamic_check, Target, is_c_or_cxx_output, |
| 393 | + is_binary_output, EmitDataAsC, EmitCodeOnly, EmitLLVMIR, |
| 394 | + EmitTritonConfig, TritonConfigFile, QuantWeights, |
| 395 | + PGQFile, RISCVOpt, cg_opts); |
548 | 396 |
|
549 | 397 | auto status = pm.Run(&m);
|
550 | 398 |
|
|
0 commit comments