Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dynamic shape & cfg from pr736 #743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
add Stack Op
  • Loading branch information
Weiming Zhao committed Dec 14, 2021
commit c8ca9428f05d99583af82a0499b7e7a4674f5169
3 changes: 2 additions & 1 deletion include/halo/lib/target/generic_cxx/generic_cxx_codegen.h
Original file line number Diff line number Diff line change
@@ -180,12 +180,13 @@ class GenericCXXCodeGen : public CodeGen {
virtual void RunOnInstruction(ShapeInst*) override;
virtual void RunOnInstruction(ShiftInst*) override;
virtual void RunOnInstruction(ShrinkInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(SItoFPInst*) override;
virtual void RunOnInstruction(SliceInst*) override;
virtual void RunOnInstruction(SoftmaxInst*) override;
virtual void RunOnInstruction(SoftplusInst*) override;
virtual void RunOnInstruction(SoftsignInst*) override;
virtual void RunOnInstruction(SigmoidInst*) override;
virtual void RunOnInstruction(StackInst*) override;
virtual void RunOnInstruction(HardSigmoidInst*) override;
virtual void RunOnInstruction(SinInst*) override;
virtual void RunOnInstruction(SinhInst*) override;
18 changes: 18 additions & 0 deletions lib/target/generic_cpp/concat.cc
Original file line number Diff line number Diff line change
@@ -38,4 +38,22 @@ void GenericCXXCodeGen::RunOnInstruction(ConcatInst* inst) {
EmitODLACall(ret, "odla_Concat", inputs, axis, EmitShape(ret_shape));
}

void GenericCXXCodeGen::RunOnInstruction(StackInst* inst) {
const auto& axis = inst->GetAxis();

CXXValue op0 = ir_mapping_[inst->GetOperand(0)];
CXXValue ret(inst->GetName(), op0.type);

ir_mapping_[*inst] = ret;
const halo::Type& ret_shape = inst->GetResultType();
const auto num = inst->GetNumOfOperands();
std::vector<CXXValue> inputs;
for (size_t i = 0; i < num; ++i) {
const Def& op = inst->GetOperand(i);
CXXValue op_v = ir_mapping_[op];
inputs.push_back(op_v);
}
EmitODLACall(ret, "odla_Stack", inputs, axis, EmitShape(ret_shape));
}

} // namespace halo