Skip to content

Commit a999eed

Browse files
fix typos
1 parent 29ad9b2 commit a999eed

File tree

1 file changed

+45
-30
lines changed

1 file changed

+45
-30
lines changed

KD_Lib/utils/pipeline.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,35 @@
22
from tqdm import tqdm
33
import time
44

5-
from KD_Lib.common import BaseClass
5+
from KD_Lib.KD.common import BaseClass
66

77

8-
class Pipeline():
8+
class Pipeline:
99
"""
1010
Pipeline of knowledge distillation, pruning and quantization methods
1111
supported by KD_Lib. Sequentially applies a list of methods on the student model.
12-
12+
1313
All the elements in list must implement either train_student, prune or quantize
1414
methods.
1515
1616
:param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization
1717
:param: epochs (int) number of iterations through whole batch for each method in
18-
list
18+
list
1919
:param: plot_losses (bool) Plot a graph of losses during training
2020
:param: save_model (bool) Save model after performing the list methods
2121
:param: save_model_pth (str) Path where model is saved if save_model is True
2222
:param: verbose (int) Verbose
2323
"""
24+
2425
def __init__(
25-
self,
26-
steps,
27-
epochs=5,
28-
plot_losses=True,
29-
save_model=True,
30-
save_model_pth="./models/student.pt",
31-
verbose=0):
26+
self,
27+
steps,
28+
epochs=5,
29+
plot_losses=True,
30+
save_model=True,
31+
save_model_pth="./models/student.pt",
32+
verbose=0,
33+
):
3234
self.steps = steps
3335
self.device = device
3436
self.verbose = verbose
@@ -43,10 +45,12 @@ def _validate_steps(self):
4345
name, process = zip(*self.steps)
4446

4547
for t in process:
46-
if (not hasattr(t, ('train_student', 'prune', 'quantize'))):
47-
raise TypeError("All the steps must support at least one of "
48-
"train_student, prune or quantize method, {} is not"
49-
" supported yet".format(str(t)))
48+
if not hasattr(t, ("train_student", "prune", "quantize")):
49+
raise TypeError(
50+
"All the steps must support at least one of "
51+
"train_student, prune or quantize method, {} is not"
52+
" supported yet".format(str(t))
53+
)
5054

5155
def get_steps(self):
5256
return self.steps
@@ -65,38 +69,49 @@ def _fit(self):
6569
for idx, name, process in self._iter():
6670
print("Starting {}".format(name))
6771
if idx != 0:
68-
if hasattr(process, 'train_student'):
69-
if hasattr(self.steps[idx-1], 'train_student'):
70-
process.student_model = self.steps[idx-1].student_model
72+
if hasattr(process, "train_student"):
73+
if hasattr(self.steps[idx - 1], "train_student"):
74+
process.student_model = self.steps[idx - 1].student_model
7175
else:
72-
process.student_model = self.steps[idx-1].model
76+
process.student_model = self.steps[idx - 1].model
7377
t1 = time.time()
74-
if hasattr(process, 'train_student'):
75-
process.train_student(self.epochs, self.plot_losses, self.save_model, self.save_model_path)
76-
elif hasattr(proces, 'prune'):
78+
if hasattr(process, "train_student"):
79+
process.train_student(
80+
self.epochs, self.plot_losses, self.save_model, self.save_model_path
81+
)
82+
elif hasattr(proces, "prune"):
7783
process.prune()
78-
elif hasattr(process, 'quantize'):
84+
elif hasattr(process, "quantize"):
7985
process.quantize()
8086
else:
81-
raise TypeError("{} is not supported by the pipeline yet."
82-
.format(process))
87+
raise TypeError(
88+
"{} is not supported by the pipeline yet.".format(process)
89+
)
8390

8491
t2 = time.time() - t1
85-
print("{} completed in {}hr {}min {}s".format(name, t2 // (60 * 60), t2 // 60, t2 % 60)
86-
92+
print(
93+
"{} completed in {}hr {}min {}s".format(
94+
name, t2 // (60 * 60), t2 // 60, t2 % 60
95+
)
96+
)
97+
8798
if self.verbose:
8899
pbar.update(1)
89-
100+
90101
if self.verbose:
91102
pbar.close()
92103

93104
def train(self):
94105
"""
95-
Train the (student) model sequentially through the list.
106+
Train the (student) model sequentially through the list.
96107
"""
97108
self._validate_steps()
98109

99110
t1 = time.time()
100111
self._fit()
101112
t2 = time.time() - t1
102-
print("Pipeline execution completed in {}hr {}min {}s".format(t2 // (60 * 60), t2 // 60, t2 % 60)
113+
print(
114+
"Pipeline execution completed in {}hr {}min {}s".format(
115+
t2 // (60 * 60), t2 // 60, t2 % 60
116+
)
117+
)

0 commit comments

Comments
 (0)