Skip to content

Commit 05f052c

Browse files
author
Weiming Zhao
committed
Add new ODLA for slice; Temporary
1 parent 6aad491 commit 05f052c

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

Diff for: lib/target/generic_cpp/slice.cc

+18-7
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,24 @@ static void NormalizerOperands(const Constant& operand,
5050
} // end namespace
5151

5252
void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
53-
const Def input = inst->GetOperand(0);
53+
const Def& input = inst->GetOperand(0);
54+
const Def& start = inst->GetOperand(1);
55+
const Def& size = inst->GetOperand(2);
56+
// auto strides = inst->GetOperand(3); //TODO
57+
58+
CXXValue op0 = ir_mapping_[input];
59+
CXXValue ret(inst->GetName(), op0.type);
60+
ir_mapping_[*inst] = ret;
61+
62+
if (!IsA<Constant>(start) || !IsA<Constant>(size)) {
63+
auto op1 = ir_mapping_[start];
64+
auto op2 = ir_mapping_[size];
65+
// auto op3 = ir_mapping_[strides]; // FIXME
66+
EmitODLACall(ret, "odla_SliceDynamic", op0, op1, op2, /*op3,*/
67+
EmitShape(inst->GetResultType()));
68+
69+
return;
70+
}
5471
size_t dims = input.GetType().GetNumOfDims();
5572
std::unordered_set<int32_t> axes;
5673
if (inst->GetNumOfOperands() > 4) {
@@ -75,7 +92,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
7592
}
7693

7794
std::vector<uint32_t> start_v(dims, 0);
78-
const Def& start = inst->GetOperand(1);
7995
HLCHECK(start.GetType().GetTotalNumOfElements() ==
8096
static_cast<int64_t>(axes.size()));
8197
HLCHECK(IsA<Constant>(start));
@@ -88,7 +104,6 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
88104

89105
std::vector<uint32_t> size_v(input.GetType().GetDimSizes().begin(),
90106
input.GetType().GetDimSizes().end());
91-
const Def& size = inst->GetOperand(2);
92107
HLCHECK(size.GetType().GetTotalNumOfElements() ==
93108
static_cast<int64_t>(axes.size()));
94109
HLCHECK(IsA<Constant>(size));
@@ -125,12 +140,8 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
125140
size_v.begin(), std::plus<uint32_t>());
126141
}
127142

128-
CXXValue op0 = ir_mapping_[input];
129-
CXXValue ret(inst->GetName(), op0.type);
130-
131143
EmitODLACall(ret, "odla_Slice", op0, start_v, size_v, strides_v,
132144
EmitShape(inst->GetResultType()));
133-
ir_mapping_[*inst] = ret;
134145
}
135146

136147
} // namespace halo

0 commit comments

Comments
 (0)