Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dynamic shape & cfg from pr736 #743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
wip
  • Loading branch information
Weiming Zhao committed Dec 10, 2021
commit c95b119506493f0b194f1cac3084e73942c8ccb8
2 changes: 1 addition & 1 deletion .github/actions/build/build_in_docker.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -xe

REPO="registry-intl.us-west-1.aliyuncs.com/computation/halo"
VER="0.7.6"
VER="0.7.7"
FLAVOR="devel"

MOUNT_DIR="$PWD"
3 changes: 3 additions & 0 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
@@ -456,6 +456,9 @@ static odla_value CreateValue(T* t, const odla_value_type& type,
auto v = std::make_unique<_odla_value>(t, type, name);
auto ret = v.get();
g_comp->vals.push_back(std::move(v));
if (!g_comp->branchs.empty()) {
g_comp->branchs.top().branch->addInput(*ret);
}
return ret;
}

1 change: 1 addition & 0 deletions include/halo/lib/pass/pass_manager.h
Original file line number Diff line number Diff line change
@@ -67,6 +67,7 @@ class HL_API_EXPORT PassManager final {
Pass* AddCodeFormatterPass(std::ostringstream& code,
std::ostringstream& header,
const CXXCodeGenOpts& opts);
Pass* AddConvertTFCFGPass();
Pass* AddDCEPass();
Pass* AddDevicePlacementPass();
Pass* AddFusionPass(const FusionOptions& opts);
3 changes: 3 additions & 0 deletions include/halo/utils/passes_helper.h
Original file line number Diff line number Diff line change
@@ -162,6 +162,9 @@ static void PopulateOptPasses(PassManager* pm, const std::string& target,
if (opts.enable_type_cast) {
pm->AddTypeCastPass();
}
if (format == ModelFormat::TENSORFLOW) {
pm->AddConvertTFCFGPass();
}
if (opts.constant_decombine) {
pm->AddConstantDecombinePass();
}
3 changes: 3 additions & 0 deletions lib/pass/pass_manager.cc
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
#include "halo/lib/transforms/analyzer.h"
#include "halo/lib/transforms/caffeextension_legalizer.h"
#include "halo/lib/transforms/constant_decombine.h"
#include "halo/lib/transforms/convert_tf_cfg.h"
#include "halo/lib/transforms/dce.h"
#include "halo/lib/transforms/device_placement.h"
#include "halo/lib/transforms/fusion.h"
@@ -265,6 +266,8 @@ Pass* PassManager::AddCodeFormatterPass(std::ostringstream& buf_code,
return AddPass<CodeFormatter>(buf_code, buf_header, opts);
}

Pass* PassManager::AddConvertTFCFGPass() { return AddPass<ConvertTFCFG>(); }

Pass* PassManager::AddDCEPass() { return AddPass<DCE>(); }

Pass* PassManager::AddDevicePlacementPass() {
6 changes: 6 additions & 0 deletions lib/target/generic_cpp/generic_cxx_codegen.cc
Original file line number Diff line number Diff line change
@@ -270,6 +270,9 @@ std::string GenericCXXCodeGen::GetFunctionDecl(const Function& func,
auto nr_outputs = ret_inst.GetNumOfOperands();
model_info.num_outputs = nr_outputs;
for (const auto& out : ret_inst.GetOperands()) {
if (out.IsNull()) {
continue;
}
const auto& type = out.GetType();
if (ir_mapping_.find(out) == ir_mapping_.end()) {
CXXValue cv(out.GetDef()->GetName(), TensorTypeToCXXType(type, false));
@@ -840,6 +843,9 @@ void GenericCXXCodeGen::RunOnFunction(Function& function) {
index = 0;
// Pre-launch binding.
for (auto& op : return_inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
auto& cv = ir_mapping_[op];
std::string arg_name = (opts_.emit_inference_func_sig || is_sub)
? (is_sub ? "outputs.values[" : "outputs[") +
3 changes: 3 additions & 0 deletions lib/target/generic_cpp/return.cc
Original file line number Diff line number Diff line change
@@ -25,6 +25,9 @@ namespace halo {
void GenericCXXCodeGen::RunOnInstruction(ReturnInst* inst) {
bool is_compile_mode = opts_.exec_mode == ExecMode::Compile;
for (auto& op : inst->GetOperands()) {
if (op.IsNull()) {
continue;
}
const CXXValue& val = ir_mapping_[op];
if (is_compile_mode) {
bool is_entry_with_calls =
1 change: 1 addition & 0 deletions lib/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@ set(SRCS
analyzer.cc
caffeextension_legalizer.cc
constant_decombine.cc
convert_tf_cfg.cc
dce.cc
device_placement.cc
fusion.cc
20 changes: 17 additions & 3 deletions lib/transforms/dce.cc
Original file line number Diff line number Diff line change
@@ -24,17 +24,28 @@

namespace halo {

static void RemoveLoopBody(LoopInst* loop_inst) {
auto body = loop_inst->GetBody();
auto return_inst = body->GetReturnInst();
static void RemoveBody(BasicBlock* bb) {
auto return_inst = bb->GetReturnInst();
if (return_inst != nullptr) {
// Drop all the operands of the return instruction so the rest of the body
// loop will be DCE'ed automatically.
// Note that the return inst cannot be erased because the current legalizer
// will try to append one if no return inst exists for a block.
return_inst->DropAllOperands();
if (bb->Instructions().size() == 1) {
bb->Instructions().clear();
return;
}
}
}
static void RemoveLoopBody(LoopInst* loop_inst) {
RemoveBody(loop_inst->GetBody());
}

static void RemoveIfBody(IfInst* if_inst) {
RemoveBody(if_inst->GetThenBranch());
RemoveBody(if_inst->GetElseBranch());
}

// For instructions with `undef` operands, they are unreachable except for
// `tf_merge` and optional operands.
@@ -85,6 +96,9 @@ bool DCE::RunOnBasicBlock(BasicBlock* bb) {
if (inst->GetOpCode() == OpCode::LOOP) {
RemoveLoopBody(DynCast<LoopInst>(inst));
}
if (inst->GetOpCode() == OpCode::IF) {
RemoveIfBody(DynCast<IfInst>(inst));
}
it = bb->Instructions().erase(it);
} else {
it = std::next(it);
1 change: 1 addition & 0 deletions lib/transforms/input_legalizer.cc
Original file line number Diff line number Diff line change
@@ -119,6 +119,7 @@ bool InputLegalizer::RunOnFunction(Function* func) {
: it->second.GetDimSizes();
arg->GetResultsTypes()[0] = halo::Type(dt, dims);
specified_shapes.erase(it);
changed = true;
}

auto dims = ty.GetDimSizes();
30 changes: 29 additions & 1 deletion lib/transforms/tfextension_legalizer.cc
Original file line number Diff line number Diff line change
@@ -489,15 +489,43 @@ static std::vector<Def> ConvertStridedSlice(const TFExtensionInst* ext,
static std::vector<Def> ConvertSwitch(const TFExtensionInst* ext,
IRBuilder* builder) {
const auto& data = ext->GetOperand(0);
if (const Constant* pred = DynCast<Constant>(ext->GetOperand(1));
const auto& cond = ext->GetOperand(1);
#if 0
if (const Constant* pred = DynCast<Constant>(cond);
pred != nullptr) {
HLCHECK(pred->GetResultType().GetTotalNumOfElements() == 1);
bool cond = pred->GetDataAsInt64(0) != 0;
std::vector<Def> ret_true{Def::GetUndefined(), data};
std::vector<Def> ret_false{data, Def::GetUndefined()};
return cond ? ret_true : ret_false;
}
#endif
// TODO(unknown): move to separate pass?
#if 1
builder->SetInsertAfter(ext);
BasicBlockBuilder bb_builder(ext->GetParent()->GetParent());
const auto& name = ext->GetName();
auto if_inst = builder->CreateIf(ext->GetName(), cond);
if_inst->AddOneOperand(data);

BasicBlock* bb_t = bb_builder.CreateBasicBlock(name + "_true");
if_inst->SetThenBranch(bb_t);
IRBuilder builder_t(bb_t);
auto arg_builder_t = std::make_unique<ArgumentBuilder>(bb_t);
auto arg_t = arg_builder_t->CreateArgument(name + "_t", data.GetType());
builder_t.CreateReturn(name + "ret_t", *arg_t);

BasicBlock* bb_f = bb_builder.CreateBasicBlock(name + "_false");
IRBuilder builder_f(bb_f);
if_inst->SetElseBranch(bb_f);
auto arg_builder_f = std::make_unique<ArgumentBuilder>(bb_f);
auto arg_f = arg_builder_f->CreateArgument(name + "_f", data.GetType());
builder_f.CreateReturn(name + "ret_f", *arg_f);
if_inst->SetNumOfResults(2);
return {Def(if_inst, 0), Def(if_inst, 1)};
#else
return {};
#endif
}

static std::vector<Def> ConvertMerge(const TFExtensionInst* ext,
24 changes: 24 additions & 0 deletions lib/transforms/type_legalizer.cc
Original file line number Diff line number Diff line change
@@ -1531,6 +1531,25 @@ static void RunOnInstruction(BitcastInst* inst) {
inst->GetResultsTypes()[0] = result_type;
}

static void RunOnInstruction(TFExtensionInst* inst) {
if (inst->GetExtOpCode() == TFExtOpCode::MERGE) {
for (auto& op : inst->GetOperands()) {
if (op.GetType().IsValid()) {
inst->GetResultsTypes()[0] = op.GetType();
return;
}
}
return;
}
if (inst->GetExtOpCode() == TFExtOpCode::SWITCH) {
const auto& ty = inst->GetOperand(0).GetType();
if (ty.IsValid()) {
inst->GetResultsTypes() = {ty, ty};
}
return;
}
}

static void RunOnInstruction(UniqueInst* inst) {
const auto& type0 = inst->GetOperand(0).GetType();
if (!type0.IsValid()) {
@@ -1569,6 +1588,11 @@ bool TypeLegalizer::RunOnBasicBlock(BasicBlock* bb) {
#define GET_INST_DOWNCAST_SWITCH
#include "halo/lib/ir/instructions_info.def"
#undef GET_INST_DOWNCAST_SWITCH
case OpCode::EXTENSION: {
TFExtensionInst* ext = DynCast<TFExtensionInst>(inst);
RunOnInstruction(ext);
break;
}
default: {
if (!relaxed_) {
// HLCHECK(0 && "Unreachable");
Loading
Oops, something went wrong.