@@ -44,9 +44,12 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
44
44
const Type& ret_type, Def op0,
45
45
Def op1, OpCode opcode,
46
46
KindPredicate pred) {
47
- if (!IsA<Constant>(op0) || !IsA<Constant>(op1) ||
48
- op0.GetType ().GetTotalNumOfElements () !=
49
- op1.GetType ().GetTotalNumOfElements ()) {
47
+ if (!IsA<Constant>(op0) || !IsA<Constant>(op1)) {
48
+ return nullptr ;
49
+ }
50
+ if ((op0.GetType ().GetTotalNumOfElements () !=
51
+ op1.GetType ().GetTotalNumOfElements () &&
52
+ op1.GetType ().GetTotalNumOfElements () != 1 )) {
50
53
return nullptr ;
51
54
}
52
55
if (opcode == OpCode::CMP) {
@@ -58,32 +61,36 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
58
61
std::swap (op0, op1);
59
62
}
60
63
}
61
- Constant* c_lhs = DynCast<Constant>(op0. GetOwner () );
62
- Constant* c_rhs = DynCast<Constant>(op1. GetOwner () );
64
+ Constant* c_lhs = DynCast<Constant>(op0);
65
+ Constant* c_rhs = DynCast<Constant>(op1);
63
66
size_t num_elements = op0.GetType ().GetTotalNumOfElements ();
64
67
Constant* c_ret = nullptr ;
65
68
ConstantBuilder cb (DynCast<Function>(c_lhs->GetParent ()));
66
69
std::vector<T> ret;
67
70
ret.reserve (num_elements);
71
+ bool rhs_is_scalar = op1.GetType ().GetTotalNumOfElements () == 1 ;
68
72
69
73
switch (opcode) {
70
74
case OpCode::ADD: {
71
75
for (size_t i = 0 ; i < num_elements; ++i) {
72
- ret.push_back (c_lhs->GetData <T>(i) + c_rhs->GetData <T>(i));
76
+ ret.push_back (c_lhs->GetData <T>(i) +
77
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
73
78
}
74
79
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
75
80
break ;
76
81
}
77
82
case OpCode::MUL: {
78
83
for (size_t i = 0 ; i < num_elements; ++i) {
79
- ret.push_back (c_lhs->GetData <T>(i) * c_rhs->GetData <T>(i));
84
+ ret.push_back (c_lhs->GetData <T>(i) *
85
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
80
86
}
81
87
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
82
88
break ;
83
89
}
84
90
case OpCode::DIV: {
85
91
for (size_t i = 0 ; i < num_elements; ++i) {
86
- ret.push_back (c_lhs->GetData <T>(i) / c_rhs->GetData <T>(i));
92
+ ret.push_back (c_lhs->GetData <T>(i) /
93
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i));
87
94
}
88
95
c_ret = cb.CreateConstant (name, ret_type, ret.data ());
89
96
break ;
@@ -93,7 +100,8 @@ static Constant* RunConstantFoldingOnMathBinary(const std::string& name,
93
100
switch (pred) {
94
101
case KindPredicate::LT: {
95
102
for (size_t i = 0 ; i < num_elements; ++i) {
96
- if (c_lhs->GetData <T>(i) < c_rhs->GetData <T>(i)) {
103
+ if (c_lhs->GetData <T>(i) <
104
+ c_rhs->GetData <T>(rhs_is_scalar ? 0 : i)) {
97
105
ret.push_back (1 );
98
106
} else {
99
107
ret.push_back (0 );
@@ -260,39 +268,14 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
260
268
const auto & op0_type = op0.GetType ();
261
269
const auto & op1_type = op1.GetType ();
262
270
OpCode opcode = binary_inst->GetOpCode ();
263
- /*
264
- // Handle scalar constant
265
- if (IsA<Constant>(op1.GetOwner())) {
266
- Constant* c_op1 = DynCast<Constant>(op1.GetOwner());
267
- Type ret_type = binary_inst->GetResultsTypes()[0];
268
- HLCHECK(ret_type.IsValid());
269
- if (c_op1->IsScalarZero()) {
270
- if (opcode == OpCode::ADD) {
271
- return {orig_def, op0};
272
- }
273
- if (opcode == OpCode::MUL) {
274
- Constant* c_zero =
275
- cb.SplatConstantZero(binary_inst->GetName(), ret_type);
276
- return {orig_def, *c_zero};
277
- }
278
- }
279
- if (c_op1->IsScalarOne()) {
280
- if (opcode == OpCode::MUL) {
281
- return {orig_def, op0};
282
- }
283
- }
284
- }*/
285
271
286
- const int64_t folding_threshold = 10 ;
287
- // Both operands are constant, do constant folding
288
- if (IsA<Constant>(op0) && IsA<Constant>(op1) &&
289
- op0_type.GetTotalNumOfElements () == op1_type.GetTotalNumOfElements () &&
290
- op0_type.GetTotalNumOfElements () < folding_threshold) {
291
- Type ret_type = binary_inst->GetResultsTypes ()[0 ];
292
- HLCHECK (ret_type.IsValid ());
272
+ // Try constant folding
273
+ const auto & ret_type = binary_inst->GetResultType ();
274
+
275
+ if (ret_type.IsValid ()) {
293
276
KindPredicate pred = KindPredicate::INVALID;
294
277
if (opcode == OpCode::CMP) {
295
- pred = static_cast <CmpInst* >(binary_inst)->GetPredicator (); // NOLINT
278
+ pred = DynCast <CmpInst>(binary_inst)->GetPredicator ();
296
279
}
297
280
if (has_swapped) {
298
281
std::swap (op0, op1);
@@ -301,19 +284,19 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
301
284
switch (op0_type.GetDataType ()) {
302
285
case DataType::INT32: {
303
286
c_ret = RunConstantFoldingOnMathBinary<int >(
304
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
287
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
305
288
pred);
306
289
break ;
307
290
}
308
291
case DataType::INT64: {
309
292
c_ret = RunConstantFoldingOnMathBinary<int64_t >(
310
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
293
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
311
294
pred);
312
295
break ;
313
296
}
314
297
case DataType::FLOAT32: {
315
298
c_ret = RunConstantFoldingOnMathBinary<float >(
316
- binary_inst->GetName () + " _folding " , ret_type, op0, op1, opcode,
299
+ binary_inst->GetName () + " _folded " , ret_type, op0, op1, opcode,
317
300
pred);
318
301
break ;
319
302
}
@@ -323,7 +306,6 @@ static std::pair<Def, Def> RunOnMathBinaryInstruction(Instruction* binary_inst,
323
306
if (c_ret != nullptr ) {
324
307
return {orig_def, *c_ret};
325
308
}
326
- return {orig_def, orig_def};
327
309
}
328
310
329
311
// Do offline broadcasting.
0 commit comments