Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Handle Resize-13 with optional ROI operand #235

Merged
merged 4 commits into from
Mar 24, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/halo/lib/ir/onnx_convert.td
Original file line number Diff line number Diff line change
@@ -288,6 +288,8 @@ def ONNX_Relu : OpMapping<"Relu", Relu>;

def ONNX_Reshape : OpMapping<"Reshape", Reshape>;

// ONNX-13 supports an optional ROI operand.
// Operands: Input, roi (optional), scales (optional)
def ONNX_Resize : OpMapping<"Resize", Resize> {
let attr_mapping_ = [
AttributeMapping<"", "axes_mask", "-1">,
28 changes: 27 additions & 1 deletion lib/transforms/inst_simplify.cc
Original file line number Diff line number Diff line change
@@ -678,9 +678,35 @@ static Constant* GetPermutedConstant(ConstantBuilder* cb, const Constant* orig,
return cb->CreateConstant(orig->GetName(), shape_type, data.data());
}

static bool IsNullConstant(const Def& op) {
if (IsA<Constant>(op)) {
auto type = DynCast<Constant>(op)->GetResultType();
if (!type.IsScalar() && type.GetTotalNumOfElements() == 0) {
return true;
}
}
return false;
}

std::pair<Def, Def> InstSimplify::RunOnInstruction(ResizeInst* inst) {
Def orig_def{inst, 0};
auto op_shape = inst->GetOperand(1);
// Check if the optional operand is valid or not.
// A null constant can be ignored.
std::vector<Def> valid_operands;
for (const auto& op : inst->GetOperands()) {
if (!IsNullConstant(op)) {
valid_operands.push_back(op);
}
}
if (valid_operands.size() < inst->GetNumOfOperands()) {
IRBuilder builder(inst->GetParent());
builder.SetInsertAfter(inst);
// Remove invalid operands.
auto new_resize = builder.Clone(*inst, valid_operands);
return {orig_def, *new_resize};
}
// Resize with 3 operands are not handled.
HLCHECK(inst->GetNumOfOperands() <= 2);
if (IsA<Instruction>(inst->GetOperand(0))) {
Instruction* op0_inst =
DynCast<Instruction>(inst->GetOperand(0).GetOwner());
18 changes: 11 additions & 7 deletions lib/transforms/type_legalizer.cc
Original file line number Diff line number Diff line change
@@ -919,12 +919,16 @@ static void RunOnInstruction(TransposeInst* inst) {
}

static void RunOnInstruction(ResizeInst* inst) {
HLCHECK(inst->GetNumOfOperands() == 2);
const auto& op1 = inst->GetOperand(1);
if (!IsA<Constant>(op1)) {
auto op_num = inst->GetNumOfOperands();
// If op_num is 3, the operands are (x, ROI, scales).
// If op_num is 2, the operands are (x, scales)
// TODO(unknown): ROI operand is not handled now.
HLCHECK(op_num == 2 || op_num == 3);
const auto& op_scale = inst->GetOperand(op_num - 1);
if (!IsA<Constant>(op_scale)) {
return;
}
const Constant* shape_c = DynCast<Constant>(op1);
const Constant* shape_c = DynCast<Constant>(op_scale);

const auto& input_type = inst->GetOperand(0).GetType();
std::vector<int64_t> new_shape = input_type.GetDimSizes();
@@ -934,11 +938,11 @@ static void RunOnInstruction(ResizeInst* inst) {
continue;
}
int64_t dim = 0;
if (op1.GetType().GetDataType() == DataType::INT64) {
if (op_scale.GetType().GetDataType() == DataType::INT64) {
dim = shape_c->GetData<int64_t>(j++);
} else if (op1.GetType().GetDataType() == DataType::INT32) {
} else if (op_scale.GetType().GetDataType() == DataType::INT32) {
dim = shape_c->GetData<int32_t>(j++);
} else if (op1.GetType().GetDataType() == DataType::FLOAT32) {
} else if (op_scale.GetType().GetDataType() == DataType::FLOAT32) {
HLCHECK(inst->GetExplicitShape() == false);
dim = std::floor(new_shape[i] * shape_c->GetData<float>(j++));
}