1 file changed: 9 additions, 4 deletions
 class CSKD(BaseClass):
     """
-    Implementation of assisted Knowledge distillation from the paper "Improved Knowledge
-    Distillation via Teacher Assistant" https://arxiv.org/pdf/1902.03393.pdf
+    Implementation of "Regularizing Class-wise Predictions via Self-knowledge Distillation"
+    https://arxiv.org/pdf/2003.13964.pdf

-    :param teacher_model (torch.nn.Module): Teacher model
+    :param teacher_model (torch.nn.Module): Teacher model -> Should be None
     :param student_model (torch.nn.Module): Student model
     :param train_loader (torch.utils.data.DataLoader): Dataloader for training
     :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
-    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
+    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher -> Should be None
     :param optimizer_student (torch.optim.*): Optimizer used for training student
     :param loss_fn (torch.nn.Module): Calculates loss during distillation
     :param temp (float): Temperature parameter for distillation
@@ -60,6 +60,11 @@ def __init__(
             logdir,
         )
         self.lamda = lamda
+        if teacher_model is not None or optimizer_teacher is not None:
+            print(
+                "Error!!! Teacher model and Teacher optimizer should be None for self-distillation, please refer to the documentation."
+            )
+            assert teacher_model == None

     def calculate_kd_loss(self, y_pred_pair_1, y_pred_pair_2):
         """