Skip to content

Commit 788c7f9

Browse files
author
Weiming Zhao
committed
[Opt] Math constant folding for scalars
[constant folding]
1 parent a9db5ea commit 788c7f9

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

lib/transforms/inst_simplify.cc

+25-20
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,12 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
4444
const Type& ret_type, Def op0,
4545
Def op1, OpCode opcode,
4646
KindPredicate pred) {
47-
if (!IsA<Constant>(op0) || !IsA<Constant>(op1) ||
48-
op0.GetType().GetTotalNumOfElements() !=
49-
op1.GetType().GetTotalNumOfElements()) {
47+
if (!IsA<Constant>(op0) || !IsA<Constant>(op1)) {
48+
return nullptr;
49+
}
50+
if ((op0.GetType().GetTotalNumOfElements() !=
51+
op1.GetType().GetTotalNumOfElements() &&
52+
op1.GetType().GetTotalNumOfElements() != 1)) {
5053
return nullptr;
5154
}
5255
if (opcode == OpCode::CMP) {
@@ -58,32 +61,36 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
5861
std::swap(op0, op1);
5962
}
6063
}
61-
Constant* c_lhs = DynCast<Constant>(op0.GetOwner());
62-
Constant* c_rhs = DynCast<Constant>(op1.GetOwner());
64+
Constant* c_lhs = DynCast<Constant>(op0);
65+
Constant* c_rhs = DynCast<Constant>(op1);
6366
size_t num_elements = op0.GetType().GetTotalNumOfElements();
6467
Constant* c_ret = nullptr;
6568
ConstantBuilder cb(DynCast<Function>(c_lhs->GetParent()));
6669
std::vector<T> ret;
6770
ret.reserve(num_elements);
71+
bool rhs_is_scalar = op1.GetType().GetTotalNumOfElements() == 1;
6872

6973
switch (opcode) {
7074
case OpCode::ADD: {
7175
for (size_t i = 0; i < num_elements; ++i) {
72-
ret.push_back(c_lhs->GetData<T>(i) + c_rhs->GetData<T>(i));
76+
ret.push_back(c_lhs->GetData<T>(i) +
77+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
7378
}
7479
c_ret = cb.CreateConstant(name, ret_type, ret.data());
7580
break;
7681
}
7782
case OpCode::MUL: {
7883
for (size_t i = 0; i < num_elements; ++i) {
79-
ret.push_back(c_lhs->GetData<T>(i) * c_rhs->GetData<T>(i));
84+
ret.push_back(c_lhs->GetData<T>(i) *
85+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
8086
}
8187
c_ret = cb.CreateConstant(name, ret_type, ret.data());
8288
break;
8389
}
8490
case OpCode::DIV: {
8591
for (size_t i = 0; i < num_elements; ++i) {
86-
ret.push_back(c_lhs->GetData<T>(i) / c_rhs->GetData<T>(i));
92+
ret.push_back(c_lhs->GetData<T>(i) /
93+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
8794
}
8895
c_ret = cb.CreateConstant(name, ret_type, ret.data());
8996
break;
@@ -93,7 +100,8 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
93100
switch (pred) {
94101
case KindPredicate::LT: {
95102
for (size_t i = 0; i < num_elements; ++i) {
96-
if (c_lhs->GetData<T>(i) < c_rhs->GetData<T>(i)) {
103+
if (c_lhs->GetData<T>(i) <
104+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i)) {
97105
ret.push_back(1);
98106
} else {
99107
ret.push_back(0);
@@ -283,16 +291,13 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
283291
}
284292
}*/
285293

286-
const int64_t folding_threshold = 10;
287-
// Both operands are constant, do constant folding
288-
if (IsA<Constant>(op0) && IsA<Constant>(op1) &&
289-
op0_type.GetTotalNumOfElements() == op1_type.GetTotalNumOfElements() &&
290-
op0_type.GetTotalNumOfElements() < folding_threshold) {
291-
Type ret_type = binary_inst->GetResultsTypes()[0];
292-
HLCHECK(ret_type.IsValid());
294+
// Try constant folding
295+
const auto& ret_type = binary_inst->GetResultType();
296+
297+
if (ret_type.IsValid()) {
293298
KindPredicate pred = KindPredicate::INVALID;
294299
if (opcode == OpCode::CMP) {
295-
pred = static_cast<CmpInst*>(binary_inst)->GetPredicator(); // NOLINT
300+
pred = DynCast<CmpInst>(binary_inst)->GetPredicator();
296301
}
297302
if (has_swapped) {
298303
std::swap(op0, op1);
@@ -301,19 +306,19 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
301306
switch (op0_type.GetDataType()) {
302307
case DataType::INT32: {
303308
c_ret = RunConstantFoldingOnMathBinary<int>(
304-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
309+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
305310
pred);
306311
break;
307312
}
308313
case DataType::INT64: {
309314
c_ret = RunConstantFoldingOnMathBinary<int64_t>(
310-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
315+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
311316
pred);
312317
break;
313318
}
314319
case DataType::FLOAT32: {
315320
c_ret = RunConstantFoldingOnMathBinary<float>(
316-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
321+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
317322
pred);
318323
break;
319324
}

0 commit comments

Comments
 (0)