Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dynamic shape & cfg #736

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add shape op to support dynamic shapes
  • Loading branch information
Weiming Zhao committed Dec 11, 2021
commit b3d5cc57648aba3ad9eccd507291151fe8cee894
2 changes: 1 addition & 1 deletion ODLA/include/ODLA/ops/odla_ops_process.h
Original file line number Diff line number Diff line change
@@ -259,8 +259,8 @@ odla_Resize(odla_value input, odla_interpolation_mode interpolation,
the result value is implementation determined.

\param input the input value
\param value_id a unique value id (can be NULL)
\param output_dims the optional output shape (can be undefined)
\param value_id a unique value id (can be NULL)

\return odla_value
*/
6 changes: 6 additions & 0 deletions include/halo/lib/ir/common_instructions.td
Original file line number Diff line number Diff line change
@@ -78,6 +78,12 @@ let cat_ = cat_common in {
let outs_ = [Arg<"The result.", MatchArgType<0> >];
}

def Shape : Inst<"Compute the shape of input."> {
let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >];
let attrs_ = [Attr<"The output date type size", EnumDataType, "data_type", "INT64">];
let outs_ = [Arg<"The result.", ArgType<[I64, I32]> >];
}

def Reshape : Inst<"Reshape the input X1 to create the result with the same"
" number of elements and the shape specified by X2."> {
let ins_ = [Arg<"The input.", ArgType<[I8,I16,I32,F16,F32]> >,
5 changes: 4 additions & 1 deletion include/halo/lib/ir/tf_convert.td
Original file line number Diff line number Diff line change
@@ -44,7 +44,10 @@ def TF_Reshape : TFExtension<"Reshape"> {
let extension_attr_ = [ ExtensionAttr<"shape", IntegerList, "{}"> ];
}

def TF_Shape : TFExtension<"Shape">;
def TF_Shape: OpMapping<"Shape", Shape> {
let attr_mapping_ = [
AttributeMapping<"", "data_type", "INT32">];
}

def TF_SquaredDifference : TFExtension<"SquaredDifference">;

1 change: 1 addition & 0 deletions include/halo/lib/target/generic_cxx/generic_cxx_codegen.h
Original file line number Diff line number Diff line change
@@ -177,6 +177,7 @@ class GenericCXXCodeGen : public CodeGen {
virtual void RunOnInstruction(ReturnInst*) override;
virtual void RunOnInstruction(RNNInst*) override;
virtual void RunOnInstruction(SelectInst*) override;
virtual void RunOnInstruction(ShapeInst*) override;
virtual void RunOnInstruction(ShiftInst*) override;
virtual void RunOnInstruction(ShrinkInst*) override;
virtual void RunOnInstruction(SItoFPInst*) override;
1 change: 1 addition & 0 deletions include/halo/lib/transforms/inst_simplify.h
Original file line number Diff line number Diff line change
@@ -65,6 +65,7 @@ class InstSimplify final : public BasicBlockPass {
static std::pair<Def, Def> RunOnInstruction(ResizeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SelectInst* inst);
static std::pair<Def, Def> RunOnInstruction(SetDiff1DInst* inst);
static std::pair<Def, Def> RunOnInstruction(ShapeInst* inst);
static std::pair<Def, Def> RunOnInstruction(SigmoidInst* inst);
static std::pair<Def, Def> RunOnInstruction(SItoFPInst* inst);
static std::pair<Def, Def> RunOnInstruction(FPtoSIInst* inst);
12 changes: 12 additions & 0 deletions lib/target/generic_cpp/reshape.cc
Original file line number Diff line number Diff line change
@@ -32,4 +32,16 @@ void GenericCXXCodeGen::RunOnInstruction(ReshapeInst* inst) {
ir_mapping_[*inst] = ret;
}

void GenericCXXCodeGen::RunOnInstruction(ShapeInst* inst) {
const Def& input = inst->GetOperand(0);

CXXValue op0 = ir_mapping_[input];

const auto& ret_type = inst->GetResultType();
CXXValue ret(inst->GetName(), op0.type);
EmitODLACall(ret, "odla_Shape", op0, EmitShape(ret_type));

ir_mapping_[*inst] = ret;
}

} // namespace halo
25 changes: 25 additions & 0 deletions lib/transforms/inst_simplify.cc
Original file line number Diff line number Diff line change
@@ -873,6 +873,31 @@ std::pair<Def, Def> InstSimplify::RunOnInstruction(Relu6Inst* inst) {
});
}

std::pair<Def, Def> InstSimplify::RunOnInstruction(ShapeInst* inst) {
const auto& type = inst->GetOperand(0).GetType();

Def orig_def{inst, 0};
if (!type.IsValid() || type.IsDynamicShape() || type.IsDynamicBatch()) {
return {orig_def, orig_def};
}

DataType dt = inst->GetDataType();
ConstantBuilder cb(inst->GetParent()->GetParent());
int64_t rank = type.GetNumOfDims();
if (dt == DataType::INT32) {
std::vector<int32_t> shape;
for (int64_t i : type.GetDimSizes()) {
shape.push_back(static_cast<int>(i));
}
Constant* c = cb.CreateConstant(inst->GetName(), halo::Type{dt, {rank}},
shape.data());
return {orig_def, *c};
}
HLCHECK(dt == DataType::INT64);
Constant* c = cb.CreateConstant(inst->GetName(), halo::Type{dt, {rank}},
type.GetDimSizes());
return {orig_def, *c};
}
std::pair<Def, Def> InstSimplify::RunOnInstruction(SigmoidInst* inst) {
return SinkTranspose(
*inst, [](IRBuilder& builder, const std::string& name, const Def& op) {
22 changes: 0 additions & 22 deletions lib/transforms/tfextension_legalizer.cc
Original file line number Diff line number Diff line change
@@ -261,25 +261,6 @@ static std::vector<Def> ConvertFill(const TFExtensionInst* ext,
return {};
}

static std::vector<Def> ConvertShape(const TFExtensionInst* ext,
IRBuilder* builder) {
auto input = ext->GetOperand(0);
const Type& input_type = input.GetType();
if (!input_type.IsValid()) {
return {};
}
std::vector<int> shape;
for (int64_t i : input_type.GetDimSizes()) {
shape.push_back(static_cast<int>(i));
}
ConstantBuilder cb(ext->GetParent()->GetParent());
Constant* c = cb.CreateConstant(
ext->GetName() + "_shape",
Type{DataType::INT32, {static_cast<int64_t>(input_type.GetNumOfDims())}},
shape.data());
return {*c};
}

static std::vector<Def> ConvertSize(const TFExtensionInst* ext,
IRBuilder* builder) {
const auto& type = ext->GetOperand(0).GetType();
@@ -1089,9 +1070,6 @@ static std::vector<Def> ConvertTFExtension(const TFExtensionInst* tf_inst,
case TFExtOpCode::SIZE: {
return ConvertSize(tf_inst, builder);
}
case TFExtOpCode::SHAPE: {
return ConvertShape(tf_inst, builder);
}
case TFExtOpCode::SPLIT: {
return ConvertSplit(tf_inst, builder);
}
10 changes: 10 additions & 0 deletions lib/transforms/type_legalizer.cc
Original file line number Diff line number Diff line change
@@ -659,6 +659,16 @@ static void RunOnCommonReductionInstruction(T* inst, std::vector<int32_t> axis,
inst->GetResultsTypes()[0] = halo::Type{dt, ret_shape};
}

static void RunOnInstruction(ShapeInst* inst) {
const Type& input_type = inst->GetOperand(0).GetType();

if (!input_type.IsValid()) {
return;
}
int rank = input_type.GetNumOfDims();
inst->GetResultsTypes()[0] = halo::Type{inst->GetDataType(), {rank}};
}

static void RunOnInstruction(ReduceL1Inst* inst) {
RunOnCommonReductionInstruction(inst, inst->GetAxis(), inst->GetKeepDims());
}