@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
   return loss
 
 
+def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
+  """Kullback-Leibler divergence between `y_true` and `y_pred`.
+
+  Computes: `loss = y_true * log(y_true / y_pred)`
+  From: tf.keras.losses.KLDivergence (custom implementation with mtf)
+  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+
+  Args:
+    y_true: mtf.Tensor, target predictions (distribution).
+    y_pred: mtf.Tensor, actual predictions (distribution).
+    reduced_dim: mtf.Dimension, reduction dimension for the sum.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, minimum value for numerical stability.
+
+  Returns:
+    scalar: K-L divergence loss.
+
+  Raises:
+    ValueError: if the shapes do not match or reduced_dim is not valid.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if reduced_dim not in y_true.shape.dims:
+    raise ValueError(
+        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
+  if weights is None:
+    weights = 1.
+
+  def _clip(x, min_value, max_value):
+    # Clip values for numerical stability.
+    x = mtf.maximum(x, min_value)
+    x = mtf.minimum(x, max_value)
+    return x
+
+  y_true = _clip(y_true, epsilon, 1.)
+  y_pred = _clip(y_pred, epsilon, 1.)
+  # The loss is fully reduced to a scalar, so summing over all dimensions
+  # subsumes the per-example sum over `reduced_dim`.
+  return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))
+
+
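
For a quick sanity check of the formula, here is a minimal NumPy sketch of the same weighted, clipped KL sum (the `kl_divergence_np` name and the plain-array inputs are illustrative, not part of this change):

import numpy as np

def kl_divergence_np(y_true, y_pred, weights=None, epsilon=1e-6):
  # Clip both distributions into [epsilon, 1] for numerical stability,
  # mirroring `_clip` above.
  y_true = np.clip(y_true, epsilon, 1.)
  y_pred = np.clip(y_pred, epsilon, 1.)
  w = 1. if weights is None else weights
  # Weighted sum of y_true * log(y_true / y_pred) over all elements.
  return np.sum(w * y_true * np.log(y_true / y_pred))

# Two 3-class distributions per batch entry; the loss is 0 iff they match.
p = np.array([[0.7, 0.2, 0.1], [0.4, 0.4, 0.2]])
q = np.array([[0.6, 0.3, 0.1], [0.3, 0.5, 0.2]])
print(kl_divergence_np(p, q))  # small positive scalar
print(kl_divergence_np(p, p))  # 0.0
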
+def mean_squared_error(y_true, y_pred, weights=None):
+  """L2 loss between `y_true` and `y_pred`.
+
+  Args:
+    y_true: mtf.Tensor, target logits.
+    y_pred: mtf.Tensor, actual logits.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+
+  Returns:
+    scalar: L2 loss.
+
+  Raises:
+    ValueError: if the shapes do not match.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if weights is None:
+    weights = 1.
+  return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))
+
+
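
Note that despite the name, the loss above is a weighted sum of squared errors rather than a mean; any averaging (e.g. by the number of unpadded positions) is left to the caller through `weights`. A NumPy equivalent, with an illustrative `_np` name and plain arrays:

import numpy as np

def mean_squared_error_np(y_true, y_pred, weights=None):
  w = 1. if weights is None else weights
  # Weighted sum of elementwise squared differences.
  return np.sum(w * np.square(y_true - y_pred))

print(mean_squared_error_np(np.array([1., 2.]), np.array([1., 4.])))  # 4.0
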
+def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
+                             epsilon=1e-6):
+  """Cosine embedding loss for distillation from teacher to student logits.
+
+  See: https://arxiv.org/abs/1910.01108 (DistilBERT) and
+  https://github.com/huggingface/transformers/tree/master/examples/
+  research_projects/distillation.
+
+  Args:
+    y_true: mtf.Tensor, teacher logits.
+    y_pred: mtf.Tensor, student logits.
+    reduced_dim: mtf.Dimension, reduction dimension for the sum.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, for numerical stability.
+
+  Returns:
+    scalar: weighted sum of cosine embedding distances.
+
+  Raises:
+    ValueError: if the shapes do not match or reduced_dim is not valid.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if reduced_dim not in y_true.shape.dims:
+    raise ValueError(
+        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
+  if weights is None:
+    weights = 1.
+
+  # Cosine similarity along `reduced_dim`:
+  # <y_true, y_pred> / (|y_true| * |y_pred|).
+  prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim)
+  y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim)
+  y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim)
+  inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon)
+  cos = prod_sum * inv_denom
+  # TODO(vinaysrao): Turn this into a more general cosine embedding loss with
+  # a `targets` tensor.
+  return mtf.reduce_sum(weights * (1. - cos))
+
+
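
The same computation in NumPy, reducing along the last axis in place of `reduced_dim` (the `_np` suffix and the `axis` argument are illustrative):

import numpy as np

def cosine_embedding_distill_np(y_true, y_pred, axis=-1, weights=None,
                                epsilon=1e-6):
  # Cosine similarity along `axis`, then 1 - cos as the distance.
  prod_sum = np.sum(y_true * y_pred, axis=axis)
  inv_denom = 1. / np.sqrt(np.sum(y_true * y_true, axis=axis) *
                           np.sum(y_pred * y_pred, axis=axis) + epsilon)
  cos = prod_sum * inv_denom
  w = 1. if weights is None else weights
  return np.sum(w * (1. - cos))

teacher = np.array([[1., 2., 3.], [0., 1., 0.]])
print(cosine_embedding_distill_np(teacher, teacher))   # ~0.0 (aligned)
print(cosine_embedding_distill_np(teacher, -teacher))  # ~4.0 (opposed)
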
 
 def sigmoid_cross_entropy_with_logits(logits, targets):
   """Sigmoid cross-entropy loss.