
Commit 934d1d3

Error handling and minor bug fixes in CSKD (#100)
* pairwise sampler added and cskd updated (see the sampler sketch after the diff)
* links added in init
* Finalised and logs added
* CSKD added with tests
* Docs added
* Testing internal kdloss
* Adding docstrings and paper summary in tutorial
* Minor correction in docs
* Error handling for non-None teacher added
* cskd reformatted
1 parent d30d065 commit 934d1d3

File tree

1 file changed (+9, -4 lines)

KD_Lib/KD/vision/CSKD/cskd.py

Lines changed: 9 additions & 4 deletions
@@ -11,14 +11,14 @@
 
 class CSKD(BaseClass):
     """
-    Implementation of assisted Knowledge distillation from the paper "Improved Knowledge
-    Distillation via Teacher Assistant" https://arxiv.org/pdf/1902.03393.pdf
+    Implementation of "Regularizing Class-wise Predictions via Self-knowledge Distillation"
+    https://arxiv.org/pdf/2003.13964.pdf
 
-    :param teacher_model (torch.nn.Module): Teacher model
+    :param teacher_model (torch.nn.Module): Teacher model -> Should be None
     :param student_model (torch.nn.Module): Student model
     :param train_loader (torch.utils.data.DataLoader): Dataloader for training
     :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
-    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
+    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher -> Should be None
     :param optimizer_student (torch.optim.*): Optimizer used for training student
     :param loss_fn (torch.nn.Module): Calculates loss during distillation
     :param temp (float): Temperature parameter for distillation
@@ -60,6 +60,11 @@ def __init__(
             logdir,
         )
         self.lamda = lamda
+        if teacher_model is not None or optimizer_teacher is not None:
+            print(
+                "Error!!! Teacher model and Teacher optimizer should be None for self-distillation, please refer to the documentation."
+            )
+            assert teacher_model == None
 
     def calculate_kd_loss(self, y_pred_pair_1, y_pred_pair_2):
         """

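The commit message mentions the pairwise sampler that feeds CSKD with same-class pairs. Below is a minimal sketch of such a sampler, assuming the pair-batch idea from the CSKD paper; the class name, constructor signature, and pairing strategy are illustrative, not KD_Lib's actual implementation:

import random
from collections import defaultdict

from torch.utils.data import Sampler


class PairSampler(Sampler):
    """Hypothetical sketch: yield each index followed by a random index
    of the same class, so batches arrive as same-class pairs."""

    def __init__(self, labels):
        self.labels = labels
        self.by_class = defaultdict(list)  # class label -> dataset indices
        for idx, label in enumerate(labels):
            self.by_class[label].append(idx)

    def __iter__(self):
        order = list(range(len(self.labels)))
        random.shuffle(order)
        for idx in order:
            yield idx  # first element of the pair
            yield random.choice(self.by_class[self.labels[idx]])  # same-class partner

    def __len__(self):
        return 2 * len(self.labels)

A DataLoader built with this sampler and an even batch size delivers interleaved same-class pairs, which is what calculate_kd_loss consumes.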
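calculate_kd_loss receives the two logit tensors of a same-class pair. A plausible sketch of the class-wise regularization from the paper, in the spirit of the file's temp and lamda parameters; the detach, reduction, and weighting here are assumptions, not copied from cskd.py:

import torch.nn.functional as F


def cskd_loss(y_pred_pair_1, y_pred_pair_2, y_true, temp=4.0, lamda=1.0):
    # Supervised term on the first element of each pair.
    ce = F.cross_entropy(y_pred_pair_1, y_true)
    # Class-wise self-KD term: pull the prediction toward the softened
    # distribution of the other same-class sample, treated as a fixed target.
    kd = F.kl_div(
        F.log_softmax(y_pred_pair_1 / temp, dim=1),
        F.softmax(y_pred_pair_2.detach() / temp, dim=1),
        reduction="batchmean",
    ) * (temp ** 2)
    return ce + lamda * kd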
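With the new guard in __init__, passing a real teacher prints the error and fails the assert, so self-distillation must be configured with both teacher arguments as None. A usage sketch: the import path and train_student call follow KD_Lib's usual pattern but are assumptions here, and the model and loaders are placeholders:

import torch
from KD_Lib.KD import CSKD

student = MyNet()  # placeholder: any torch.nn.Module classifier
optimizer_student = torch.optim.SGD(student.parameters(), lr=0.1)

distiller = CSKD(
    teacher_model=None,         # must be None, or the new check trips
    student_model=student,
    train_loader=train_loader,  # placeholder DataLoaders, ideally built
    val_loader=val_loader,      # with a same-class pair sampler
    optimizer_teacher=None,     # must also be None
    optimizer_student=optimizer_student,
)
distiller.train_student()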