This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 972a038

Author: Mesh TensorFlow Team

Use multiple target objectives for distillation. Also see cl/356382304

PiperOrigin-RevId: 356382406

1 parent 9625f34


mesh_tensorflow/transformer/transformer.py

Lines changed: 45 additions & 8 deletions
@@ -1650,7 +1650,10 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
-               distill_start_step=0,
+               mse_coeff=0.,
+               kl_coeff=0.,
+               cosine_coeff=0.,
+               distill_start_steps=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
     """Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
         target cross entropy to the training loss. The rest of the loss will be
         the cross entropy with the one-hot actual label. Required only when
         training.
-      distill_start_step: an int, training steps after which teacher loss is
+      mse_coeff: MSE distillation loss coefficient.
+      kl_coeff: KL-divergence distillation loss coefficient.
+      cosine_coeff: cosine-embedding distillation loss coefficient.
+      distill_start_steps: an int, training steps after which teacher loss is
         incorporated in the overall loss.
       teacher_checkpoint: a string, the path to the teacher checkpoint that we
         wish to use. Required only when training.
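
For orientation, a hypothetical call using the extended signature might look as follows. This is not from the commit: the student and teacher objects, the checkpoint path, and the assumption that the signature begins with a student argument (suggested by self.student elsewhere in the class) are all illustrative.

# Hypothetical usage of the extended constructor; every value is a placeholder.
st = StudentTeacher(
    student=student_model,       # assumed leading argument, per self.student
    teacher=teacher_model,
    temperature=2.0,
    fraction_soft=0.5,           # weight of the soft cross-entropy term
    mse_coeff=0.0,               # new: MSE distillation weight
    kl_coeff=0.2,                # new: KL-divergence distillation weight
    cosine_coeff=0.0,            # new: cosine-embedding distillation weight
    distill_start_steps=10000,   # renamed from distill_start_step
    teacher_checkpoint="/path/to/teacher/checkpoint")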
@@ -1676,9 +1682,15 @@ def __init__(self,
     self.teacher = teacher
     self.temperature = temperature
     self.fraction_soft = fraction_soft
-    self.distill_start_step = distill_start_step
+    self.distill_start_steps = distill_start_steps
     self.teacher_checkpoint = teacher_checkpoint
     self.initialize_student_weights = initialize_student_weights
+    self.kl_coeff = kl_coeff
+    self.cosine_coeff = cosine_coeff
+    self.mse_coeff = mse_coeff
+    if (fraction_soft + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+      raise ValueError("Distillation coefficients must not add up to a value "
+                       "greater than 1.")
 
   def call_simple(self,
                   inputs,
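
A minimal sketch (not part of the commit; the helper name is invented) of the budget arithmetic behind the new ValueError guard: the four coefficients share a single unit budget, and whatever remains implicitly weights the hard label loss once distillation starts.

def check_distill_budget(fraction_soft, kl_coeff, cosine_coeff, mse_coeff):
  # Mirrors the guard in __init__ above.
  total = fraction_soft + kl_coeff + cosine_coeff + mse_coeff
  if total > 1.:
    raise ValueError("Distillation coefficients must not add up to a value "
                     "greater than 1.")
  return 1. - total  # remaining weight on the hard cross-entropy loss

# Example: 0.5 (soft CE) + 0.2 (KL) + 0.0 + 0.0 = 0.7, leaving 0.3 hard.
assert abs(check_distill_budget(0.5, 0.2, 0.0, 0.0) - 0.3) < 1e-9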
@@ -1751,15 +1763,40 @@ def call_simple(self,
       weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
       soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                    self.student.loss_denominator(targets, num_microbatches))
+      if self.kl_coeff > 0.:
+        student_pred = mtf.softmax(student_logits / self.temperature,
+                                   output_vocab_dim)
+        kl_loss = mtf.layers.kl_divergence(
+            mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+            weights=weights)
+      else:
+        kl_loss = 0.
+      if self.cosine_coeff > 0.:
+        cosine_loss = mtf.layers.cosine_embedding_distill(
+            mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+            weights=weights)
+      else:
+        cosine_loss = 0.
+      if self.mse_coeff > 0.:
+        mse_loss = mtf.layers.kl_divergence(
+            mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+            weights=weights)
+      else:
+        mse_loss = 0.
       global_step = tf.train.get_or_create_global_step()
-      current_fraction_soft = tf.cast(
+      distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                               self.mse_coeff + self.cosine_coeff)
+      current_distill_fraction = tf.cast(
           tf.cond(
-              tf.math.greater(global_step, self.distill_start_step),
-              lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+              tf.math.greater(global_step, self.distill_start_steps),
+              lambda: distill_loss_fraction, lambda: tf.constant(0.0)),
           dtype=tf.bfloat16)
 
-      loss = (1.0 - current_fraction_soft) * hard_loss \
-          + self.temperature**2 * current_fraction_soft * soft_loss
+      loss = (1.0 - current_distill_fraction) * hard_loss \
+          + current_distill_fraction * (
+              self.temperature**2 * soft_loss * self.fraction_soft +
+              self.kl_coeff * kl_loss + self.mse_coeff * mse_loss +
+              self.cosine_coeff * cosine_loss)
 
       return student_logits, loss
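
Pulled out of the diff, the blended objective is easier to read as plain arithmetic. The sketch below is illustrative only: the *_loss arguments stand in for the mtf tensors computed in call_simple, and plain Python replaces the tf.cond gating on the global step. The temperature**2 factor applies only to the soft cross-entropy term, the usual correction (Hinton et al., 2015) that keeps soft-target gradients on a comparable scale as the temperature grows.

def blended_distill_loss(hard_loss, soft_loss, kl_loss, cosine_loss, mse_loss,
                         fraction_soft, kl_coeff, cosine_coeff, mse_coeff,
                         temperature, global_step, distill_start_steps):
  # Total share of the loss given to distillation terms.
  distill_fraction = fraction_soft + kl_coeff + cosine_coeff + mse_coeff
  # Distillation is switched on only after distill_start_steps.
  current = distill_fraction if global_step > distill_start_steps else 0.0
  return (1.0 - current) * hard_loss + current * (
      temperature**2 * soft_loss * fraction_soft +
      kl_coeff * kl_loss + mse_coeff * mse_loss +
      cosine_coeff * cosine_loss)

# Before step 100 the result is exactly hard_loss; afterwards the hard loss
# keeps weight 1 - (0.5 + 0.2) = 0.3.
loss = blended_distill_loss(hard_loss=2.0, soft_loss=1.5, kl_loss=0.4,
                            cosine_loss=0.0, mse_loss=0.0, fraction_soft=0.5,
                            kl_coeff=0.2, cosine_coeff=0.0, mse_coeff=0.0,
                            temperature=2.0, global_step=1000,
                            distill_start_steps=100)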