diff --git a/models/official/efficientnet/efficientnet_model.py b/models/official/efficientnet/efficientnet_model.py index 9b969f75f..0290330eb 100644 --- a/models/official/efficientnet/efficientnet_model.py +++ b/models/official/efficientnet/efficientnet_model.py @@ -421,7 +421,8 @@ def call(self, inputs, training=True, survival_prob=None): if self._block_args.id_skip: if all( s == 1 for s in self._block_args.strides - ) and inputs.get_shape().as_list()[-1] == x.get_shape().as_list()[-1]: + ) and (inputs.get_shape().as_list()[self._channel_axis] == + x.get_shape().as_list()[self._channel_axis]): # Apply only if skip connection presents. if survival_prob: x = utils.drop_connect(x, training, survival_prob)