@@ -1650,7 +1650,10 @@ def __init__(self,
                teacher,
                temperature=None,
                fraction_soft=None,
-               distill_start_step=0,
+               mse_coeff=0.,
+               kl_coeff=0.,
+               cosine_coeff=0.,
+               distill_start_steps=0,
                teacher_checkpoint=None,
                initialize_student_weights=False):
    """Create a StudentTeacher.
@@ -1664,7 +1667,10 @@ def __init__(self,
        target cross entropy to the training loss. The rest of the loss will be
        the cross entropy with the one-hot actual label. Required only when
        training.
-      distill_start_step: an int, training steps after which teacher loss is
+      mse_coeff: a float, MSE distillation loss coefficient.
+      kl_coeff: a float, KL-divergence distillation loss coefficient.
+      cosine_coeff: a float, cosine-embedding distillation loss coefficient.
+      distill_start_steps: an int, training steps after which teacher loss is
        incorporated in the overall loss.
      teacher_checkpoint: a string, the path to the teacher checkpoint that we
        wish to use. Required only when training.
@@ -1676,9 +1682,15 @@ def __init__(self,
    self.teacher = teacher
    self.temperature = temperature
    self.fraction_soft = fraction_soft
-    self.distill_start_step = distill_start_step
+    self.distill_start_steps = distill_start_steps
    self.teacher_checkpoint = teacher_checkpoint
    self.initialize_student_weights = initialize_student_weights
+    self.kl_coeff = kl_coeff
+    self.cosine_coeff = cosine_coeff
+    self.mse_coeff = mse_coeff
+    if (fraction_soft + kl_coeff + cosine_coeff + mse_coeff) > 1.:
+      raise ValueError("Distillation coefficients must not add up to a value "
+                       "greater than 1.")

  def call_simple(self,
                  inputs,
@@ -1751,15 +1763,40 @@ def call_simple(self,
    weights = mtf.cast(mtf.greater(targets, 0), soft_loss.dtype)
    soft_loss = (mtf.reduce_sum(soft_loss * weights) /
                 self.student.loss_denominator(targets, num_microbatches))
+    if self.kl_coeff > 0.:
+      student_pred = mtf.softmax(student_logits / self.temperature,
+                                 output_vocab_dim)
+      kl_loss = mtf.layers.kl_divergence(
+          mtf.stop_gradient(soft_targets), student_pred, output_vocab_dim,
+          weights=weights)
+    else:
+      kl_loss = 0.
+    if self.cosine_coeff > 0.:
+      cosine_loss = mtf.layers.cosine_embedding_distill(
+          mtf.stop_gradient(teacher_logits), student_logits, output_vocab_dim,
+          weights=weights)
+    else:
+      cosine_loss = 0.
+    if self.mse_coeff > 0.:
+      mse_loss = mtf.reduce_mean(  # masked MSE between teacher/student logits
+          mtf.square(mtf.stop_gradient(teacher_logits) - student_logits)
+          * weights)
+    else:
+      mse_loss = 0.
    global_step = tf.train.get_or_create_global_step()
-    current_fraction_soft = tf.cast(
+    distill_loss_fraction = (self.fraction_soft + self.kl_coeff +
+                             self.mse_coeff + self.cosine_coeff)
+    current_distill_fraction = tf.cast(
        tf.cond(
-            tf.math.greater(global_step, self.distill_start_step),
-            lambda: self.fraction_soft, lambda: tf.constant(0.0)),
+            tf.math.greater(global_step, self.distill_start_steps),
+            lambda: distill_loss_fraction, lambda: tf.constant(0.0)),
        dtype=tf.bfloat16)

-    loss = (1.0 - current_fraction_soft) * hard_loss \
-        + self.temperature ** 2 * current_fraction_soft * soft_loss
+    loss = (1.0 - current_distill_fraction) * hard_loss \
+        + current_distill_fraction * (
+            self.temperature ** 2 * soft_loss * self.fraction_soft +
+            self.kl_coeff * kl_loss + self.mse_coeff * mse_loss +
+            self.cosine_coeff * cosine_loss)

    return student_logits, loss
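For reference, here is a minimal, self-contained sketch of how the combined loss introduced in the hunk above is assembled once global_step passes distill_start_steps. The numeric values are illustrative placeholders only, and plain Python floats stand in for the mtf tensors used in the real code:

    # Illustrative scalar stand-ins for the mtf tensor losses in the patch.
    temperature = 2.0
    fraction_soft, kl_coeff, mse_coeff, cosine_coeff = 0.3, 0.2, 0.1, 0.1
    hard_loss, soft_loss, kl_loss, mse_loss, cosine_loss = 1.2, 0.8, 0.5, 0.4, 0.3

    # Total weight given to the distillation terms; before distill_start_steps
    # the tf.cond in the patch swaps this for 0.0, so only hard_loss contributes.
    distill_fraction = fraction_soft + kl_coeff + mse_coeff + cosine_coeff
    assert distill_fraction <= 1.0  # mirrors the ValueError check in __init__

    loss = (1.0 - distill_fraction) * hard_loss \
        + distill_fraction * (temperature ** 2 * soft_loss * fraction_soft
                              + kl_coeff * kl_loss
                              + mse_coeff * mse_loss
                              + cosine_coeff * cosine_loss)
    print(round(loss, 4))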
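And, assuming this change is applied and that StudentTeacher lives at its usual location in mesh_tensorflow, a hypothetical construction with the new arguments might look like the following. The student/teacher keyword names for the model objects, the None placeholders, and the checkpoint path are assumptions for illustration, not part of the patch:

    from mesh_tensorflow.transformer import transformer

    # Placeholders; in practice these are already-built student/teacher models.
    student_model = None
    teacher_model = None

    student_teacher = transformer.StudentTeacher(
        student=student_model,
        teacher=teacher_model,
        temperature=2.0,
        fraction_soft=0.3,        # soft cross-entropy share of the loss
        kl_coeff=0.2,             # KL-divergence distillation share
        mse_coeff=0.1,            # MSE distillation share
        cosine_coeff=0.1,         # cosine-embedding distillation share
        distill_start_steps=10000,
        teacher_checkpoint="/path/to/teacher/checkpoint")
    # The four coefficients above sum to 0.7; a sum above 1. raises ValueError.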