Skip to content

Commit 6aa728d

Browse files
Weiming Zhao (weimingzha0)
Weiming Zhao
authored and committed
[ODLA/DNNL] Fix batchnorm
1 parent 66a3da3 commit 6aa728d

File tree

3 files changed

+23
-25
lines changed

3 files changed

+23
-25
lines changed

ODLA/platforms/dnnl/odla_dnnl.cc

+21-23
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ odla_value odla_BatchNormalization(odla_value input,
10901090
odla_value offset, odla_float32 scalar_scale,
10911091
odla_float32 scalar_offset,
10921092
const odla_value_id value_id) {
1093-
dnnl::memory oring_mem;
1093+
dnnl::memory origin_mem;
10941094
dnnl::memory::data_type dtype = input->mem.get_desc().data_type();
10951095
// black list op should convert to fp32
10961096
bool bf16_mode = (dtype == dnnl::memory::data_type::bf16 ||
@@ -1105,7 +1105,7 @@ odla_value odla_BatchNormalization(odla_value input,
11051105
scale->mem = cast_op(scale, dnnl::memory::data_type::f32);
11061106
offset->mem = cast_op(offset, dnnl::memory::data_type::f32);
11071107
}
1108-
oring_mem = input->mem;
1108+
origin_mem = input->mem;
11091109
input->mem = f32_input_mem;
11101110
}
11111111

@@ -1121,25 +1121,23 @@ odla_value odla_BatchNormalization(odla_value input,
11211121
}
11221122

11231123
unsigned channels = input_dims.dims[1];
1124-
dnnl::memory::desc weight_md(dnnl::memory::dims{2, channels}, type,
1125-
dnnl::memory::format_tag::nc);
1126-
dnnl::memory weight_mem = dnnl::memory(weight_md, g_comp->eng);
1127-
1128-
if (scale != nullptr && offset != nullptr) {
1124+
dnnl::memory scale_offset_mem = dnnl::memory();
1125+
if (scale != nullptr || offset != nullptr || scalar_offset != 0.0F ||
1126+
scalar_scale != 1.0F) {
1127+
// make a tensor [scale, bias].
1128+
auto get_value = [channels](odla_value x, float scalar) {
1129+
if (x == nullptr) {
1130+
x = odla_CreateConstant({ODLA_FLOAT32, {2, {1, 1}}}, &scalar,
1131+
nullptr); // FIXME: copy to buf
1132+
}
1133+
return odla_Reshape(x, {2, {1, channels}}, nullptr);
1134+
};
1135+
odla_value s = get_value(scale, scalar_scale);
1136+
odla_value b = get_value(offset, scalar_offset);
11291137
flags |= dnnl::normalization_flags::use_scale_shift;
1130-
auto scale_md =
1131-
dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc);
1132-
auto scale_mem =
1133-
dnnl::memory(scale_md, g_comp->eng, scale->mem.get_data_handle());
1134-
auto offset_md =
1135-
dnnl::memory::desc({1, channels}, type, dnnl::memory::format_tag::nc);
1136-
auto c_pd = dnnl::concat::primitive_desc(
1137-
weight_md, 0, {scale_md, offset_md}, g_comp->eng);
1138-
auto c = dnnl::concat(c_pd);
1139-
c.execute(dnnl::stream(g_comp->eng),
1140-
{{DNNL_ARG_MULTIPLE_SRC, scale->mem},
1141-
{DNNL_ARG_MULTIPLE_SRC + 1, offset->mem},
1142-
{DNNL_ARG_DST, weight_mem}});
1138+
auto scale_offset =
1139+
odla_Concat({2, {s, b}}, 0, {2, {2, channels}}, nullptr);
1140+
scale_offset_mem = scale_offset->mem;
11431141
}
11441142
auto op_desc = dnnl::batch_normalization_forward::desc(
11451143
dnnl::prop_kind::forward, input_md, epsilon, flags);
@@ -1148,15 +1146,15 @@ odla_value odla_BatchNormalization(odla_value input,
11481146
auto prim = dnnl::batch_normalization_forward(pd);
11491147
auto ret_mem = dnnl::memory(input_md, g_comp->eng);
11501148

1151-
odla_value v = CreateValue(ret_mem, orig_dims, value_id);
11521149
add_op(prim, {{DNNL_ARG_SRC, input->mem},
11531150
{DNNL_ARG_MEAN, mean->mem},
11541151
{DNNL_ARG_VARIANCE, var->mem},
1155-
{DNNL_ARG_SCALE_SHIFT, weight_mem},
1152+
{DNNL_ARG_SCALE_SHIFT, scale_offset_mem},
11561153
{DNNL_ARG_DST, ret_mem}});
1154+
odla_value v = CreateValue(ret_mem, orig_dims, value_id);
11571155
if (g_comp->opts.bf16_mode == BF16_PERFORMACE_MODE) {
11581156
v->mem = cast_op(v, dnnl::memory::data_type::bf16);
1159-
input->mem = oring_mem;
1157+
input->mem = origin_mem;
11601158
}
11611159

11621160
InterpretIfNeeded();

tests/unittests/lit_cases/test_dnnl/test_batchnorm_epsilon_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@
2929
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_epsilon | FileCheck %s
3030
// CHECK: Result Pass
3131
// clang-format on
32-
// XFAIL: *
32+
3333
#include "test_batchnorm_epsilon_dnnl.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_dnnl/test_batchnorm_example_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@
2929
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_batchnorm_example | FileCheck %s
3030
// CHECK: Result Pass
3131
// clang-format on
32-
// XFAIL: *
32+
3333
#include "test_batchnorm_example_dnnl.cc.tmp.main.cc.in"

0 commit comments

Comments
 (0)