@@ -1090,7 +1090,7 @@ odla_value odla_BatchNormalization(odla_value input,
1090
1090
odla_value offset, odla_float32 scalar_scale,
1091
1091
odla_float32 scalar_offset,
1092
1092
const odla_value_id value_id) {
1093
- dnnl::memory oring_mem ;
1093
+ dnnl::memory origin_mem ;
1094
1094
dnnl::memory::data_type dtype = input->mem .get_desc ().data_type ();
1095
1095
// black list op should convert to fp32
1096
1096
bool bf16_mode = (dtype == dnnl::memory::data_type::bf16 ||
@@ -1105,7 +1105,7 @@ odla_value odla_BatchNormalization(odla_value input,
1105
1105
scale->mem = cast_op (scale, dnnl::memory::data_type::f32);
1106
1106
offset->mem = cast_op (offset, dnnl::memory::data_type::f32);
1107
1107
}
1108
- oring_mem = input->mem ;
1108
+ origin_mem = input->mem ;
1109
1109
input->mem = f32_input_mem;
1110
1110
}
1111
1111
@@ -1121,25 +1121,23 @@ odla_value odla_BatchNormalization(odla_value input,
1121
1121
}
1122
1122
1123
1123
unsigned channels = input_dims.dims [1 ];
1124
- dnnl::memory::desc weight_md (dnnl::memory::dims{2 , channels}, type,
1125
- dnnl::memory::format_tag::nc);
1126
- dnnl::memory weight_mem = dnnl::memory (weight_md, g_comp->eng );
1127
-
1128
- if (scale != nullptr && offset != nullptr ) {
1124
+ dnnl::memory scale_offset_mem = dnnl::memory ();
1125
+ if (scale != nullptr || offset != nullptr || scalar_offset != 0 .0F ||
1126
+ scalar_scale != 1 .0F ) {
1127
+ // make a tensor [scale, bias].
1128
+ auto get_value = [channels](odla_value x, float scalar) {
1129
+ if (x == nullptr ) {
1130
+ x = odla_CreateConstant ({ODLA_FLOAT32, {2 , {1 , 1 }}}, &scalar,
1131
+ nullptr ); // FIXME: copy to buf
1132
+ }
1133
+ return odla_Reshape (x, {2 , {1 , channels}}, nullptr );
1134
+ };
1135
+ odla_value s = get_value (scale, scalar_scale);
1136
+ odla_value b = get_value (offset, scalar_offset);
1129
1137
flags |= dnnl::normalization_flags::use_scale_shift;
1130
- auto scale_md =
1131
- dnnl::memory::desc ({1 , channels}, type, dnnl::memory::format_tag::nc);
1132
- auto scale_mem =
1133
- dnnl::memory (scale_md, g_comp->eng , scale->mem .get_data_handle ());
1134
- auto offset_md =
1135
- dnnl::memory::desc ({1 , channels}, type, dnnl::memory::format_tag::nc);
1136
- auto c_pd = dnnl::concat::primitive_desc (
1137
- weight_md, 0 , {scale_md, offset_md}, g_comp->eng );
1138
- auto c = dnnl::concat (c_pd);
1139
- c.execute (dnnl::stream (g_comp->eng ),
1140
- {{DNNL_ARG_MULTIPLE_SRC, scale->mem },
1141
- {DNNL_ARG_MULTIPLE_SRC + 1 , offset->mem },
1142
- {DNNL_ARG_DST, weight_mem}});
1138
+ auto scale_offset =
1139
+ odla_Concat ({2 , {s, b}}, 0 , {2 , {2 , channels}}, nullptr );
1140
+ scale_offset_mem = scale_offset->mem ;
1143
1141
}
1144
1142
auto op_desc = dnnl::batch_normalization_forward::desc (
1145
1143
dnnl::prop_kind::forward, input_md, epsilon, flags);
@@ -1148,15 +1146,15 @@ odla_value odla_BatchNormalization(odla_value input,
1148
1146
auto prim = dnnl::batch_normalization_forward (pd);
1149
1147
auto ret_mem = dnnl::memory (input_md, g_comp->eng );
1150
1148
1151
- odla_value v = CreateValue (ret_mem, orig_dims, value_id);
1152
1149
add_op (prim, {{DNNL_ARG_SRC, input->mem },
1153
1150
{DNNL_ARG_MEAN, mean->mem },
1154
1151
{DNNL_ARG_VARIANCE, var->mem },
1155
- {DNNL_ARG_SCALE_SHIFT, weight_mem },
1152
+ {DNNL_ARG_SCALE_SHIFT, scale_offset_mem },
1156
1153
{DNNL_ARG_DST, ret_mem}});
1154
+ odla_value v = CreateValue (ret_mem, orig_dims, value_id);
1157
1155
if (g_comp->opts .bf16_mode == BF16_PERFORMACE_MODE) {
1158
1156
v->mem = cast_op (v, dnnl::memory::data_type::bf16);
1159
- input->mem = oring_mem ;
1157
+ input->mem = origin_mem ;
1160
1158
}
1161
1159
1162
1160
InterpretIfNeeded ();
0 commit comments