Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] dynamic shape & cfg from pr736 #743

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Prev Previous commit
fix:slice_op parameters' type-uint2int
  • Loading branch information
littlefatfat committed Dec 16, 2021
commit a6f8250f6fd52e747458f720e766b269d5031d32
4 changes: 2 additions & 2 deletions ODLA/include/ODLA/ops/odla_ops_process.h
Original file line number Diff line number Diff line change
@@ -282,8 +282,8 @@ odla_Shape(odla_value input, odla_value_shape output_dims,
\return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Slice(odla_value input, const odla_uint32* start, const odla_uint32* end,
const odla_uint32* stride, odla_value_shape output_dims,
odla_Slice(odla_value input, const odla_int32* start, const odla_int32* end,
const odla_int32* stride, odla_value_shape output_dims,
const odla_value_id value_id);

//! \brief Extract a dynamic slice
8 changes: 4 additions & 4 deletions ODLA/platforms/dnnl/odla_dnnl.cc
Original file line number Diff line number Diff line change
@@ -1834,8 +1834,8 @@ odla_value odla_Erf(odla_value input, const odla_value_id value_id) {

static void strided_slice(const void* src, int elem_size,
const odla_value_shape& input_dims,
const odla_uint32* start, const odla_uint32* end,
const odla_uint32* strides, void* dst,
const odla_int32* start, const odla_int32* end,
const odla_int32* strides, void* dst,
const odla_value_shape& output_dims) {
int64_t dst_elems = GetTotalElements(output_dims);
int dims = input_dims.size;
@@ -1871,8 +1871,8 @@ static void strided_slice(const void* src, int elem_size,
}
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* strides,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* strides,
odla_value_shape output_dims, const odla_value_id id) {
const auto& input_dims = input->shape;
int dims = input_dims.size;
4 changes: 2 additions & 2 deletions ODLA/platforms/odla_popart/odla_ops.cc
Original file line number Diff line number Diff line change
@@ -359,8 +359,8 @@ odla_value odla_ReduceMean(odla_value input, odla_size_t num_of_axes,
return result;
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* stride,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* stride,
odla_value_shape output_dims, const odla_value_id id) {
const auto& name = id ? std::string(reinterpret_cast<const char*>(id)) : "";

4 changes: 2 additions & 2 deletions ODLA/platforms/odla_profiler.cc
Original file line number Diff line number Diff line change
@@ -522,8 +522,8 @@ odla_value odla_Rsqrt(odla_value input, const odla_value_id id) {
}

static constexpr const char fn_slice[] = "odla_Slice";
odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* strides,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* strides,
odla_value_shape output_dims, const odla_value_id id) {
return profile<fn_slice>(input, start, end, strides, output_dims, id);
}
4 changes: 2 additions & 2 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
@@ -1838,8 +1838,8 @@ odla_value odla_Gather(odla_value input, const odla_value indices,
return CreateValue(gather, {input->type.element_type, output_dims}, id);
}

odla_value odla_Slice(odla_value input, const odla_uint32* start,
const odla_uint32* end, const odla_uint32* stride,
odla_value odla_Slice(odla_value input, const odla_int32* start,
const odla_int32* end, const odla_int32* stride,
odla_value_shape output_dims, const odla_value_id id) {
odla_value_shape start_dims, stride_dims;
const auto& dims = input->type.shape;
21 changes: 10 additions & 11 deletions lib/target/generic_cpp/slice.cc
Original file line number Diff line number Diff line change
@@ -33,15 +33,14 @@ namespace {
template <typename T>
static void NormalizerOperands(const Constant& operand,
const std::unordered_set<int32_t>& axes,
const size_t dims,
std::vector<uint32_t>* value) {
const size_t dims, std::vector<int32_t>* value) {
bool onnx_mode = axes.size() != dims;
for (size_t i = 0, j = 0; i < dims; ++i) {
if (axes.count(i) != 0) {
(*value)[i] = static_cast<uint32_t>(operand.GetData<T>(j++));
(*value)[i] = static_cast<int32_t>(operand.GetData<T>(j++));
} else {
if (!onnx_mode) {
(*value)[i] = static_cast<uint32_t>(operand.GetData<T>(i));
(*value)[i] = static_cast<int32_t>(operand.GetData<T>(i));
}
}
}
@@ -91,7 +90,7 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
}
}

std::vector<uint32_t> start_v(dims, 0);
std::vector<int32_t> start_v(dims, 0);
HLCHECK(start.GetType().GetTotalNumOfElements() ==
static_cast<int64_t>(axes.size()));
HLCHECK(IsA<Constant>(start));
@@ -102,8 +101,8 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
NormalizerOperands<int64_t>(*start_c, axes, dims, &start_v);
}

std::vector<uint32_t> size_v(input.GetType().GetDimSizes().begin(),
input.GetType().GetDimSizes().end());
std::vector<int32_t> size_v(input.GetType().GetDimSizes().begin(),
input.GetType().GetDimSizes().end());
HLCHECK(size.GetType().GetTotalNumOfElements() ==
static_cast<int64_t>(axes.size()));
HLCHECK(IsA<Constant>(size));
@@ -114,7 +113,7 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {
NormalizerOperands<int64_t>(*size_c, axes, dims, &size_v);
}

std::vector<uint32_t> strides_v(dims, 1);
std::vector<int32_t> strides_v(dims, 1);
if (inst->GetNumOfOperands() > 3) {
const Def& strides = inst->GetOperand(3);
HLCHECK(IsA<Constant>(strides));
@@ -129,15 +128,15 @@ void GenericCXXCodeGen::RunOnInstruction(SliceInst* inst) {

// stride is provided, calculate ends = starts + sizes * strides
std::for_each(strides_v.begin(), strides_v.end(),
[=](uint32_t& s) { s = s >= 0 ? s : dims + s; });
[=](int32_t& s) { s = s >= 0 ? s : dims + s; });
std::transform(strides_v.begin(), strides_v.end(), size_v.begin(),
size_v.begin(), std::multiplies<uint32_t>());
std::transform(start_v.begin(), start_v.end(), size_v.begin(),
size_v.begin(), std::plus<uint32_t>());
size_v.begin(), std::plus<int32_t>());
} else {
// stride is omitted, set to [1,1,...,1], calculate ends = starts + sizes
std::transform(size_v.begin(), size_v.end(), start_v.begin(),
size_v.begin(), std::plus<uint32_t>());
size_v.begin(), std::plus<int32_t>());
}

EmitODLACall(ret, "odla_Slice", op0, start_v, size_v, strides_v,