This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 9f15745

Author: Mesh TensorFlow Team

Add loss functions for multiple-target objectives for distillation.

PiperOrigin-RevId: 356382304

1 parent: 9625f34

File tree

1 file changed: mesh_tensorflow/layers.py

Lines changed: 99 additions & 0 deletions
@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
  return loss

def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
  """Kullback-Leibler divergence between `y_true` and `y_pred`.

  Computes: `loss = y_true * log(y_true / y_pred)`
  From: tf.keras.losses.KLDivergence (custom implementation with mtf).
  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

  Args:
    y_true: mtf.Tensor, target predictions (distribution).
    y_pred: mtf.Tensor, actual predictions (distribution).
    reduced_dim: mtf.Dimension, reduction dimension for the sum.
    weights: Optional mtf.Tensor, indicator for padded regions.
    epsilon: float, minimum value for numerical stability.

  Returns:
    scalar: K-L divergence loss.

  Raises:
    ValueError: if the shapes do not match or reduced_dim is not valid.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if reduced_dim not in y_true.shape.dims:
    raise ValueError(
        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
  if weights is None:
    weights = 1.

  def _clip(x, min_value, max_value):
    # Clip values for numerical stability.
    x = mtf.maximum(x, min_value)
    x = mtf.minimum(x, max_value)
    return x

  y_true = _clip(y_true, epsilon, 1.)
  y_pred = _clip(y_pred, epsilon, 1.)
  # Reduces over all dimensions to a scalar; `weights` masks padded regions.
  return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))

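# --- Editor's note: usage sketch, not part of this commit. ---
# A minimal example of calling `kl_divergence` above; the mesh, dimension
# names, and sizes are hypothetical. Both inputs must already be probability
# distributions over `reduced_dim` (e.g. softmax outputs).
def _kl_divergence_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 4)
  vocab = mtf.Dimension("vocab", 8)
  shape = mtf.Shape([batch, vocab])
  # Normalize random logits into distributions over `vocab`.
  teacher_probs = mtf.softmax(mtf.random_uniform(mesh, shape),
                              reduced_dim=vocab)
  student_probs = mtf.softmax(mtf.random_uniform(mesh, shape),
                              reduced_dim=vocab)
  return kl_divergence(teacher_probs, student_probs, reduced_dim=vocab)
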
def mean_squared_error(y_true, y_pred, weights=None):
  """L2 loss between `y_true` and `y_pred`.

  Args:
    y_true: mtf.Tensor, target logits.
    y_pred: mtf.Tensor, actual logits.
    weights: Optional mtf.Tensor, indicator for padded regions.

  Returns:
    scalar: L2 loss.

  Raises:
    ValueError: if the shapes do not match.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if weights is None:
    weights = 1.
  return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))

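# --- Editor's note: usage sketch, not part of this commit. ---
# Despite its name, `mean_squared_error` returns a weighted *sum* of squared
# errors (via mtf.reduce_sum); divide by the number of unpadded positions if
# a true mean is needed. Names and sizes below are hypothetical.
def _mean_squared_error_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch = mtf.Dimension("batch", 4)
  length = mtf.Dimension("length", 16)
  shape = mtf.Shape([batch, length])
  y_true = mtf.random_uniform(mesh, shape)
  y_pred = mtf.random_uniform(mesh, shape)
  # 1.0 at real positions, 0.0 at padding, so padded regions add no loss.
  weights = mtf.ones(mesh, shape)
  return mean_squared_error(y_true, y_pred, weights=weights)
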
def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
                             epsilon=1e-6):
  """Cosine embedding loss for distillation from teacher to student logits.

  See: https://arxiv.org/abs/1910.01108 (DistilBert) and
  https://github.com/huggingface/transformers/tree/master/examples/research_projects/distillation.

  Args:
    y_true: mtf.Tensor, teacher logits.
    y_pred: mtf.Tensor, student logits.
    reduced_dim: mtf.Dimension, reduction dimension for the sum.
    weights: Optional mtf.Tensor, indicator for padded regions.
    epsilon: float, for numerical stability.

  Returns:
    scalar: summed cosine embedding distance.

  Raises:
    ValueError: if the shapes do not match or reduced_dim is not valid.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if reduced_dim not in y_true.shape.dims:
    raise ValueError(
        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
  if weights is None:
    weights = 1.

  # Cosine similarity along `reduced_dim`:
  # cos = <y_true, y_pred> / (||y_true|| * ||y_pred||).
  prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim)
  y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim)
  y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim)
  inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon)
  cos = prod_sum * inv_denom
  # TODO(vinaysrao): Turn this into a more general cosine embedding loss with
  # a `targets` tensor.
  return mtf.reduce_sum(weights * (1. - cos))

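# --- Editor's note: sketch of a multiple-target distillation objective, not
# part of this commit. It combines the three losses above in the spirit of
# the commit message; the mixing coefficients and input tensors (teacher /
# student logits and hidden states) are hypothetical.
def _distillation_objective_sketch(teacher_logits, student_logits,
                                   teacher_hidden, student_hidden,
                                   vocab_dim, model_dim, weights=None):
  # Soft-target KL term over the vocabulary distribution.
  kd_loss = kl_divergence(mtf.softmax(teacher_logits, reduced_dim=vocab_dim),
                          mtf.softmax(student_logits, reduced_dim=vocab_dim),
                          reduced_dim=vocab_dim, weights=weights)
  # Align student hidden states with the teacher's (DistilBert-style).
  cos_loss = cosine_embedding_distill(teacher_hidden, student_hidden,
                                      reduced_dim=model_dim, weights=weights)
  # Optional direct logit-regression term.
  l2_loss = mean_squared_error(teacher_logits, student_logits, weights=weights)
  # Hypothetical mixing coefficients; tune per task.
  return 0.5 * kd_loss + 0.3 * cos_loss + 0.2 * l2_loss
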
def sigmoid_cross_entropy_with_logits(logits, targets):
  """Sigmoid cross-entropy loss.
