Skip to content

Commit 037f5a7

Browse files
Weiming Zhaoweimingzha0
Weiming Zhao
authored andcommitted
[Opt] Math constant folding for scalars
Support constant folding for math binary opterations when lhs or rhs is a constant scalar.
1 parent 178cac6 commit 037f5a7

File tree

1 file changed

+25
-43
lines changed

1 file changed

+25
-43
lines changed

lib/transforms/inst_simplify.cc

+25-43
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
4444
const Type& ret_type, Def op0,
4545
Def op1, OpCode opcode,
4646
KindPredicate pred) {
47-
if (!IsA<Constant>(op0) || !IsA<Constant>(op1) ||
48-
op0.GetType().GetTotalNumOfElements() !=
49-
op1.GetType().GetTotalNumOfElements()) {
47+
if (!IsA<Constant>(op0) || !IsA<Constant>(op1)) {
48+
return nullptr;
49+
}
50+
if ((op0.GetType().GetTotalNumOfElements() !=
51+
op1.GetType().GetTotalNumOfElements() &&
52+
op1.GetType().GetTotalNumOfElements() != 1)) {
5053
return nullptr;
5154
}
5255
if (opcode == OpCode::CMP) {
@@ -58,32 +61,36 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
5861
std::swap(op0, op1);
5962
}
6063
}
61-
Constant* c_lhs = DynCast<Constant>(op0.GetOwner());
62-
Constant* c_rhs = DynCast<Constant>(op1.GetOwner());
64+
Constant* c_lhs = DynCast<Constant>(op0);
65+
Constant* c_rhs = DynCast<Constant>(op1);
6366
size_t num_elements = op0.GetType().GetTotalNumOfElements();
6467
Constant* c_ret = nullptr;
6568
ConstantBuilder cb(DynCast<Function>(c_lhs->GetParent()));
6669
std::vector<T> ret;
6770
ret.reserve(num_elements);
71+
bool rhs_is_scalar = op1.GetType().GetTotalNumOfElements() == 1;
6872

6973
switch (opcode) {
7074
case OpCode::ADD: {
7175
for (size_t i = 0; i < num_elements; ++i) {
72-
ret.push_back(c_lhs->GetData<T>(i) + c_rhs->GetData<T>(i));
76+
ret.push_back(c_lhs->GetData<T>(i) +
77+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
7378
}
7479
c_ret = cb.CreateConstant(name, ret_type, ret.data());
7580
break;
7681
}
7782
case OpCode::MUL: {
7883
for (size_t i = 0; i < num_elements; ++i) {
79-
ret.push_back(c_lhs->GetData<T>(i) * c_rhs->GetData<T>(i));
84+
ret.push_back(c_lhs->GetData<T>(i) *
85+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
8086
}
8187
c_ret = cb.CreateConstant(name, ret_type, ret.data());
8288
break;
8389
}
8490
case OpCode::DIV: {
8591
for (size_t i = 0; i < num_elements; ++i) {
86-
ret.push_back(c_lhs->GetData<T>(i) / c_rhs->GetData<T>(i));
92+
ret.push_back(c_lhs->GetData<T>(i) /
93+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i));
8794
}
8895
c_ret = cb.CreateConstant(name, ret_type, ret.data());
8996
break;
@@ -93,7 +100,8 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
93100
switch (pred) {
94101
case KindPredicate::LT: {
95102
for (size_t i = 0; i < num_elements; ++i) {
96-
if (c_lhs->GetData<T>(i) < c_rhs->GetData<T>(i)) {
103+
if (c_lhs->GetData<T>(i) <
104+
c_rhs->GetData<T>(rhs_is_scalar ? 0 : i)) {
97105
ret.push_back(1);
98106
} else {
99107
ret.push_back(0);
@@ -260,39 +268,14 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
260268
const auto& op0_type = op0.GetType();
261269
const auto& op1_type = op1.GetType();
262270
OpCode opcode = binary_inst->GetOpCode();
263-
/*
264-
// Handle scalar constant
265-
if (IsA<Constant>(op1.GetOwner())) {
266-
Constant* c_op1 = DynCast<Constant>(op1.GetOwner());
267-
Type ret_type = binary_inst->GetResultsTypes()[0];
268-
HLCHECK(ret_type.IsValid());
269-
if (c_op1->IsScalarZero()) {
270-
if (opcode == OpCode::ADD) {
271-
return {orig_def, op0};
272-
}
273-
if (opcode == OpCode::MUL) {
274-
Constant* c_zero =
275-
cb.SplatConstantZero(binary_inst->GetName(), ret_type);
276-
return {orig_def, *c_zero};
277-
}
278-
}
279-
if (c_op1->IsScalarOne()) {
280-
if (opcode == OpCode::MUL) {
281-
return {orig_def, op0};
282-
}
283-
}
284-
}*/
285271

286-
const int64_t folding_threshold = 10;
287-
// Both operands are constant, do constant folding
288-
if (IsA<Constant>(op0) && IsA<Constant>(op1) &&
289-
op0_type.GetTotalNumOfElements() == op1_type.GetTotalNumOfElements() &&
290-
op0_type.GetTotalNumOfElements() < folding_threshold) {
291-
Type ret_type = binary_inst->GetResultsTypes()[0];
292-
HLCHECK(ret_type.IsValid());
272+
// Try constant folding
273+
const auto& ret_type = binary_inst->GetResultType();
274+
275+
if (ret_type.IsValid()) {
293276
KindPredicate pred = KindPredicate::INVALID;
294277
if (opcode == OpCode::CMP) {
295-
pred = static_cast<CmpInst*>(binary_inst)->GetPredicator(); // NOLINT
278+
pred = DynCast<CmpInst>(binary_inst)->GetPredicator();
296279
}
297280
if (has_swapped) {
298281
std::swap(op0, op1);
@@ -301,19 +284,19 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
301284
switch (op0_type.GetDataType()) {
302285
case DataType::INT32: {
303286
c_ret = RunConstantFoldingOnMathBinary<int>(
304-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
287+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
305288
pred);
306289
break;
307290
}
308291
case DataType::INT64: {
309292
c_ret = RunConstantFoldingOnMathBinary<int64_t>(
310-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
293+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
311294
pred);
312295
break;
313296
}
314297
case DataType::FLOAT32: {
315298
c_ret = RunConstantFoldingOnMathBinary<float>(
316-
binary_inst->GetName() + "_folding", ret_type, op0, op1, opcode,
299+
binary_inst->GetName() + "_folded", ret_type, op0, op1, opcode,
317300
pred);
318301
break;
319302
}
@@ -323,7 +306,6 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
323306
if (c_ret != nullptr) {
324307
return {orig_def, *c_ret};
325308
}
326-
return {orig_def, orig_def};
327309
}
328310

329311
// Do offline broadcasting.

0 commit comments

Comments
 (0)