@@ -50,7 +50,24 @@ static void NormalizerOperands(const Constant& operand,
50
50
} // end namespace
51
51
52
52
void GenericCXXCodeGen::RunOnInstruction (SliceInst* inst) {
53
- const Def input = inst->GetOperand (0 );
53
+ const Def& input = inst->GetOperand (0 );
54
+ const Def& start = inst->GetOperand (1 );
55
+ const Def& size = inst->GetOperand (2 );
56
+ // auto strides = inst->GetOperand(3); //TODO
57
+
58
+ CXXValue op0 = ir_mapping_[input];
59
+ CXXValue ret (inst->GetName (), op0.type );
60
+ ir_mapping_[*inst] = ret;
61
+
62
+ if (!IsA<Constant>(start) || !IsA<Constant>(size)) {
63
+ auto op1 = ir_mapping_[start];
64
+ auto op2 = ir_mapping_[size];
65
+ // auto op3 = ir_mapping_[strides]; // FIXME
66
+ EmitODLACall (ret, " odla_SliceDynamic" , op0, op1, op2, /* op3,*/
67
+ EmitShape (inst->GetResultType ()));
68
+
69
+ return ;
70
+ }
54
71
size_t dims = input.GetType ().GetNumOfDims ();
55
72
std::unordered_set<int32_t > axes;
56
73
if (inst->GetNumOfOperands () > 4 ) {
@@ -75,7 +92,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
75
92
}
76
93
77
94
std::vector<uint32_t > start_v (dims, 0 );
78
- const Def& start = inst->GetOperand (1 );
79
95
HLCHECK (start.GetType ().GetTotalNumOfElements () ==
80
96
static_cast <int64_t >(axes.size ()));
81
97
HLCHECK (IsA<Constant>(start));
@@ -88,7 +104,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
88
104
89
105
std::vector<uint32_t > size_v (input.GetType ().GetDimSizes ().begin (),
90
106
input.GetType ().GetDimSizes ().end ());
91
- const Def& size = inst->GetOperand (2 );
92
107
HLCHECK (size.GetType ().GetTotalNumOfElements () ==
93
108
static_cast <int64_t >(axes.size ()));
94
109
HLCHECK (IsA<Constant>(size));
@@ -125,12 +140,8 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
125
140
size_v.begin (), std::plus<uint32_t >());
126
141
}
127
142
128
- CXXValue op0 = ir_mapping_[input];
129
- CXXValue ret (inst->GetName (), op0.type );
130
-
131
143
EmitODLACall (ret, " odla_Slice" , op0, start_v, size_v, strides_v,
132
144
EmitShape (inst->GetResultType ()));
133
- ir_mapping_[*inst] = ret;
134
145
}
135
146
136
147
} // namespace halo
0 commit comments