@@ -44,9 +44,12 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
44
44
const Type& ret_type, Def op0,
45
45
Def op1, OpCode opcode,
46
46
KindPredicate pred) {
47
- if (!IsA<Constant>(op0) || !IsA<Constant>(op1) ||
48
- op0.GetType ().GetTotalNumOfElements () !=
49
- op1.GetType ().GetTotalNumOfElements ()) {
47
+ if (!IsA<Constant>(op0) || !IsA<Constant>(op1)) {
48
+ return nullptr ;
49
+ }
50
+ if ((op0.GetType ().GetTotalNumOfElements () !=
51
+ op1.GetType ().GetTotalNumOfElements () &&
52
+ op1.GetType ().GetTotalNumOfElements () != 1 )) {
50
53
return nullptr ;
51
54
}
52
55
if (opcode == OpCode::CMP) {
@@ -58,32 +61,36 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
58
61
std::swap (op0, op1);
59
62
}
60
63
}
61
- Constant* c_lhs = DynCast<Constant>(op0. GetOwner () );
62
- Constant* c_rhs = DynCast<Constant>(op1. GetOwner () );
64
+ Constant* c_lhs = DynCast<Constant>(op0);
65
+ Constant* c_rhs = DynCast<Constant>(op1);
63
66
size_t num_elements = op0.GetType ().GetTotalNumOfElements ();
64
67
Constant* c_ret = nullptr ;
65
68
ConstantBuilder cb (DynCast<Function>(c_lhs->GetParent ()));
66
69
std::vector<T> ret;
67
70
ret.reserve (num_elements);
71
+ bool rhs_is_scalar = op1.GetType ().GetTotalNumOfElements () == 1 ;
68
72
69
73
switch (opcode) {
70
74
case OpCode::ADD: {
71
75
for (size_t i = 0 ; i < num_elements; ++i) {
72
- ret.push_back (c_lhs->GetData <T>(i) + c_rhs->GetData <T>(i));
76
+ ret.push_back (c_lhs->GetData <T>(i) +
77
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
73
78
}
74
79
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
75
80
break ;
76
81
}
77
82
case OpCode::MUL: {
78
83
for (size_t i = 0 ; i < num_elements; ++i) {
79
- ret.push_back (c_lhs->GetData <T>(i) * c_rhs->GetData <T>(i));
84
+ ret.push_back (c_lhs->GetData <T>(i) *
85
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
80
86
}
81
87
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
82
88
break ;
83
89
}
84
90
case OpCode::DIV: {
85
91
for (size_t i = 0 ; i < num_elements; ++i) {
86
- ret.push_back (c_lhs->GetData <T>(i) / c_rhs->GetData <T>(i));
92
+ ret.push_back (c_lhs->GetData <T>(i) /
93
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
87
94
}
88
95
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
89
96
break ;
@@ -93,7 +100,8 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
93
100
switch (pred) {
94
101
case KindPredicate::LT: {
95
102
for (size_t i = 0 ; i < num_elements; ++i) {
96
- if (c_lhs->GetData <T>(i) < c_rhs->GetData <T>(i)) {
103
+ if (c_lhs->GetData <T>(i) <
104
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i)) {
97
105
ret.push_back (1 );
98
106
} else {
99
107
ret.push_back (0 );
@@ -283,16 +291,13 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
283
291
}
284
292
}*/
285
293
286
- const int64_t folding_threshold = 10 ;
287
- // Both operands are constant, do constant folding
288
- if (IsA<Constant>(op0) && IsA<Constant>(op1) &&
289
- op0_type.GetTotalNumOfElements () == op1_type.GetTotalNumOfElements () &&
290
- op0_type.GetTotalNumOfElements () < folding_threshold) {
291
- Type ret_type = binary_inst->GetResultsTypes ()[0 ];
292
- HLCHECK (ret_type.IsValid ());
294
+ // Try constant folding
295
+ const auto & ret_type = binary_inst->GetResultType ();
296
+
297
+ if (ret_type.IsValid ()) {
293
298
KindPredicate pred = KindPredicate::INVALID;
294
299
if (opcode == OpCode::CMP) {
295
- pred = static_cast <CmpInst* >(binary_inst)->GetPredicator (); // NOLINT
300
+ pred = DynCast <CmpInst>(binary_inst)->GetPredicator ();
296
301
}
297
302
if (has_swapped) {
298
303
std::swap (op0, op1);
@@ -301,19 +306,19 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
301
306
switch (op0_type.GetDataType ()) {
302
307
case DataType::INT32: {
303
308
c_ret = RunConstantFoldingOnMathBinary<int >(
304
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
309
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
305
310
pred);
306
311
break ;
307
312
}
308
313
case DataType::INT64: {
309
314
c_ret = RunConstantFoldingOnMathBinary<int64_t >(
310
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
315
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
311
316
pred);
312
317
break ;
313
318
}
314
319
case DataType::FLOAT32: {
315
320
c_ret = RunConstantFoldingOnMathBinary<float >(
316
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
321
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
317
322
pred);
318
323
break ;
319
324
}
0 commit comments