2
2
from tqdm import tqdm
3
3
import time
4
4
5
- from KD_Lib .common import BaseClass
5
+ from KD_Lib .KD . common import BaseClass
6
6
7
7
8
- class Pipeline () :
8
+ class Pipeline :
9
9
"""
10
10
Pipeline of knowledge distillation, pruning and quantization methods
11
11
supported by KD_Lib. Sequentially applies a list of methods on the student model.
12
-
12
+
13
13
All the elements in list must implement either train_student, prune or quantize
14
14
methods.
15
15
16
16
:param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization
17
17
:param: epochs (int) number of iterations through whole batch for each method in
18
- list
18
+ list
19
19
:param: plot_losses (bool) Plot a graph of losses during training
20
20
:param: save_model (bool) Save model after performing the list methods
21
21
:param: save_model_pth (str) Path where model is saved if save_model is True
22
22
:param: verbose (int) Verbose
23
23
"""
24
+
24
25
def __init__ (
25
- self ,
26
- steps ,
27
- epochs = 5 ,
28
- plot_losses = True ,
29
- save_model = True ,
30
- save_model_pth = "./models/student.pt" ,
31
- verbose = 0 ):
26
+ self ,
27
+ steps ,
28
+ epochs = 5 ,
29
+ plot_losses = True ,
30
+ save_model = True ,
31
+ save_model_pth = "./models/student.pt" ,
32
+ verbose = 0 ,
33
+ ):
32
34
self .steps = steps
33
35
self .device = device
34
36
self .verbose = verbose
@@ -43,10 +45,12 @@ def _validate_steps(self):
43
45
name , process = zip (* self .steps )
44
46
45
47
for t in process :
46
- if (not hasattr (t , ('train_student' , 'prune' , 'quantize' ))):
47
- raise TypeError ("All the steps must support at least one of "
48
- "train_student, prune or quantize method, {} is not"
49
- " supported yet" .format (str (t )))
48
+ if not hasattr (t , ("train_student" , "prune" , "quantize" )):
49
+ raise TypeError (
50
+ "All the steps must support at least one of "
51
+ "train_student, prune or quantize method, {} is not"
52
+ " supported yet" .format (str (t ))
53
+ )
50
54
51
55
def get_steps (self ):
52
56
return self .steps
@@ -65,38 +69,49 @@ def _fit(self):
65
69
for idx , name , process in self ._iter ():
66
70
print ("Starting {}" .format (name ))
67
71
if idx != 0 :
68
- if hasattr (process , ' train_student' ):
69
- if hasattr (self .steps [idx - 1 ], ' train_student' ):
70
- process .student_model = self .steps [idx - 1 ].student_model
72
+ if hasattr (process , " train_student" ):
73
+ if hasattr (self .steps [idx - 1 ], " train_student" ):
74
+ process .student_model = self .steps [idx - 1 ].student_model
71
75
else :
72
- process .student_model = self .steps [idx - 1 ].model
76
+ process .student_model = self .steps [idx - 1 ].model
73
77
t1 = time .time ()
74
- if hasattr (process , 'train_student' ):
75
- process .train_student (self .epochs , self .plot_losses , self .save_model , self .save_model_path )
76
- elif hasattr (proces , 'prune' ):
78
+ if hasattr (process , "train_student" ):
79
+ process .train_student (
80
+ self .epochs , self .plot_losses , self .save_model , self .save_model_path
81
+ )
82
+ elif hasattr (proces , "prune" ):
77
83
process .prune ()
78
- elif hasattr (process , ' quantize' ):
84
+ elif hasattr (process , " quantize" ):
79
85
process .quantize ()
80
86
else :
81
- raise TypeError ("{} is not supported by the pipeline yet."
82
- .format (process ))
87
+ raise TypeError (
88
+ "{} is not supported by the pipeline yet." .format (process )
89
+ )
83
90
84
91
t2 = time .time () - t1
85
- print ("{} completed in {}hr {}min {}s" .format (name , t2 // (60 * 60 ), t2 // 60 , t2 % 60 )
86
-
92
+ print (
93
+ "{} completed in {}hr {}min {}s" .format (
94
+ name , t2 // (60 * 60 ), t2 // 60 , t2 % 60
95
+ )
96
+ )
97
+
87
98
if self .verbose :
88
99
pbar .update (1 )
89
-
100
+
90
101
if self .verbose :
91
102
pbar .close ()
92
103
93
104
def train (self ):
94
105
"""
95
- Train the (student) model sequentially through the list.
106
+ Train the (student) model sequentially through the list.
96
107
"""
97
108
self ._validate_steps ()
98
109
99
110
t1 = time .time ()
100
111
self ._fit ()
101
112
t2 = time .time () - t1
102
- print ("Pipeline execution completed in {}hr {}min {}s" .format (t2 // (60 * 60 ), t2 // 60 , t2 % 60 )
113
+ print (
114
+ "Pipeline execution completed in {}hr {}min {}s" .format (
115
+ t2 // (60 * 60 ), t2 // 60 , t2 % 60
116
+ )
117
+ )
0 commit comments