From 78256eb27df903949f33a1a30019d701e4d9e14f Mon Sep 17 00:00:00 2001 From: "xuhongyao.xhy" Date: Mon, 22 Mar 2021 06:31:51 +0000 Subject: [PATCH 1/2] dnnl resize support nhwc --- ODLA/platforms/dnnl/odla_dnnl.cc | 23 ++++++++++++++++++++--- lib/transforms/inst_simplify.cc | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/ODLA/platforms/dnnl/odla_dnnl.cc b/ODLA/platforms/dnnl/odla_dnnl.cc index 6b9f94867..f8be5aced 100644 --- a/ODLA/platforms/dnnl/odla_dnnl.cc +++ b/ODLA/platforms/dnnl/odla_dnnl.cc @@ -1348,11 +1348,28 @@ odla_value odla_Resize(odla_value input, odla_interpolation_mode interpolation, auto input_md = input->mem.get_desc(); auto dt = input->mem.get_desc().data_type(); - auto ret_md = dnnl::memory::desc(getDims(output_dims), dt, - dnnl::memory::format_tag::nchw); + auto format_tag = dnnl::memory::format_tag::nchw; + + float scale_h; + float scale_w; + std::vector scales = {1.0f, 1.0f, 1.0f, 1.0f}; + if (axes_mask == -1) { + scale_h = 1.0f * output_dims.dims[1] / input_md.dims()[1]; + scale_w = 1.0f * output_dims.dims[2] / input_md.dims()[2]; + scales[1] = scale_h; + scales[2] = scale_w; + format_tag = dnnl::memory::format_tag::nhwc; + } else { + scale_h = 1.0f * output_dims.dims[2] / input_md.dims()[2]; + scale_w = 1.0f * output_dims.dims[3] / input_md.dims()[3]; + scales[2] = scale_h; + scales[3] = scale_w; + } + + auto ret_md = dnnl::memory::desc(getDims(output_dims), dt,format_tag); auto op_desc = dnnl::resampling_forward::desc( - dnnl::prop_kind::forward_inference, algo, input_md, ret_md); + dnnl::prop_kind::forward_inference, algo, scales, input_md); auto pd = dnnl::resampling_forward::primitive_desc(op_desc, g_comp->eng); auto prim = dnnl::resampling_forward(pd); diff --git a/lib/transforms/inst_simplify.cc b/lib/transforms/inst_simplify.cc index e16c1cf96..a6f39f927 100644 --- a/lib/transforms/inst_simplify.cc +++ b/lib/transforms/inst_simplify.cc @@ -709,12 +709,12 @@ std::pair InstSimplify::RunOnInstruction(ResizeInst* inst) { } new_shape->SetName(inst->GetName() + "_resize_shape"); - return SinkTranspose( *inst, [new_shape, inst](IRBuilder& builder, const std::string& name, const Def& op) { auto new_inst = builder.CreateResize(name, {op, *new_shape}); new_inst->CopyAttrsFrom(*inst); + new_inst->SetAxesMask(-1); return new_inst; }); } From 9e5c826c6b7973ac86592cd872bda1982db37922 Mon Sep 17 00:00:00 2001 From: "xuhongyao.xhy" Date: Mon, 22 Mar 2021 07:43:02 +0000 Subject: [PATCH 2/2] format code --- ODLA/platforms/dnnl/odla_dnnl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ODLA/platforms/dnnl/odla_dnnl.cc b/ODLA/platforms/dnnl/odla_dnnl.cc index f8be5aced..52838e11f 100644 --- a/ODLA/platforms/dnnl/odla_dnnl.cc +++ b/ODLA/platforms/dnnl/odla_dnnl.cc @@ -1366,7 +1366,7 @@ odla_value odla_Resize(odla_value input, odla_interpolation_mode interpolation, scales[3] = scale_w; } - auto ret_md = dnnl::memory::desc(getDims(output_dims), dt,format_tag); + auto ret_md = dnnl::memory::desc(getDims(output_dims), dt, format_tag); auto op_desc = dnnl::resampling_forward::desc( dnnl::prop_kind::forward_inference, algo, scales, input_md);