Skip to content

Commit 505b9d4

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[ONNX] Handle Resize-13 with optional ROI operand
This patch handles a special case of Resize-13: the ROI operand is null. If the optional ROI operand is not null, it will assert.
1 parent 789f53f commit 505b9d4

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

include/halo/lib/ir/onnx_convert.td

+2
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ def ONNX_Relu : OpMapping<"Relu", Relu>;
288288

289289
def ONNX_Reshape : OpMapping<"Reshape", Reshape>;
290290

291+
// ONNX-13 supports an optional ROI operand.
292+
// Operands: Input, roi (optional), scales (optional)
291293
def ONNX_Resize : OpMapping<"Resize", Resize> {
292294
let attr_mapping_ = [
293295
AttributeMapping<"", "axes_mask", "-1">,

lib/transforms/inst_simplify.cc

+27-1
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,35 @@ static Constant* GetPermutedConstant(ConstantBuilder* cb, const Constant* orig,
678678
return cb->CreateConstant(orig->GetName(), shape_type, data.data());
679679
}
680680

681+
static bool IsNullConstant(const Def& op) {
682+
if (IsA<Constant>(op)) {
683+
auto type = DynCast<Constant>(op)->GetResultType();
684+
if (!type.IsScalar() && type.GetTotalNumOfElements() == 0) {
685+
return true;
686+
}
687+
}
688+
return false;
689+
}
690+
681691
std::pair<Def, Def> InstSimplify::RunOnInstruction(ResizeInst* inst) {
682692
Def orig_def{inst, 0};
683-
auto op_shape = inst->GetOperand(1);
693+
// Check if the optional operand is valid or not.
694+
// A null constant can be ignored.
695+
std::vector<Def> valid_operands;
696+
for (const auto& op : inst->GetOperands()) {
697+
if (!IsNullConstant(op)) {
698+
valid_operands.push_back(op);
699+
}
700+
}
701+
if (valid_operands.size() < inst->GetNumOfOperands()) {
702+
IRBuilder builder(inst->GetParent());
703+
builder.SetInsertAfter(inst);
704+
// Remove invalid operands.
705+
auto new_resize = builder.Clone(*inst, valid_operands);
706+
return {orig_def, *new_resize};
707+
}
708+
// Resize with 3 operands are not handled.
709+
HLCHECK(inst->GetNumOfOperands() <= 2);
684710
if (IsA<Instruction>(inst->GetOperand(0))) {
685711
Instruction* op0_inst =
686712
DynCast<Instruction>(inst->GetOperand(0).GetOwner());

lib/transforms/type_legalizer.cc

+11-7
Original file line numberDiff line numberDiff line change
@@ -919,12 +919,16 @@ static void RunOnInstruction(TransposeInst* inst) {
919919
}
920920

921921
static void RunOnInstruction(ResizeInst* inst) {
922-
HLCHECK(inst->GetNumOfOperands() == 2);
923-
const auto& op1 = inst->GetOperand(1);
924-
if (!IsA<Constant>(op1)) {
922+
auto op_num = inst->GetNumOfOperands();
923+
// If op_num is 3, the operands are (x, ROI, scales).
924+
// If op_num is 2, the operands are (x, scales)
925+
// TODO(unknown): ROI operand is not handled now.
926+
HLCHECK(op_num == 2 || op_num == 3);
927+
const auto& op_scale = inst->GetOperand(op_num - 1);
928+
if (!IsA<Constant>(op_scale)) {
925929
return;
926930
}
927-
const Constant* shape_c = DynCast<Constant>(op1);
931+
const Constant* shape_c = DynCast<Constant>(op_scale);
928932

929933
const auto& input_type = inst->GetOperand(0).GetType();
930934
std::vector<int64_t> new_shape = input_type.GetDimSizes();
@@ -934,11 +938,11 @@ static void RunOnInstruction(ResizeInst* inst) {
934938
continue;
935939
}
936940
int64_t dim = 0;
937-
if (op1.GetType().GetDataType() == DataType::INT64) {
941+
if (op_scale.GetType().GetDataType() == DataType::INT64) {
938942
dim = shape_c->GetData<int64_t>(j++);
939-
} else if (op1.GetType().GetDataType() == DataType::INT32) {
943+
} else if (op_scale.GetType().GetDataType() == DataType::INT32) {
940944
dim = shape_c->GetData<int32_t>(j++);
941-
} else if (op1.GetType().GetDataType() == DataType::FLOAT32) {
945+
} else if (op_scale.GetType().GetDataType() == DataType::FLOAT32) {
942946
HLCHECK(inst->GetExplicitShape() == false);
943947
dim = std::floor(new_shape[i] * shape_c->GetData<float>(j++));
944948
}

0 commit comments

Comments
 (0)