Skip to content

Commit ccddfb9

Browse files
add pipeline
1 parent 934d1d3 commit ccddfb9

File tree

4 files changed

+216
-54
lines changed

4 files changed

+216
-54
lines changed

KD_Lib/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .pipeline import Pipeline

KD_Lib/utils/pipeline.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from itertools import islice
2+
from tqdm import tqdm
3+
import time
4+
5+
from KD_Lib.common import BaseClass
6+
7+
8+
class Pipeline():
    """
    Pipeline of knowledge distillation, pruning and quantization methods
    supported by KD_Lib. Sequentially applies a list of methods on the student model.

    All the elements in list must implement either train_student, prune or quantize
    methods. A step may be given as a ``(name, process)`` pair or as the bare
    process object, in which case its class name is used as the step name.

    :param: steps (list) list of KD_Lib.KD or KD_Lib.Pruning or KD_Lib.Quantization
    :param: epochs (int) number of iterations through whole batch for each method in
                         list
    :param: plot_losses (bool) Plot a graph of losses during training
    :param: save_model (bool) Save model after performing the list methods
    :param: save_model_pth (str) Path where model is saved if save_model is True
    :param: verbose (int) Verbose
    """

    # A step is usable when it exposes at least one of these methods.
    _SUPPORTED_METHODS = ("train_student", "prune", "quantize")

    def __init__(
            self,
            steps,
            epochs=5,
            plot_losses=True,
            save_model=True,
            save_model_pth="./models/student.pt",
            verbose=0):
        self.steps = steps
        self.epochs = epochs
        self.verbose = verbose

        self.plot_losses = plot_losses
        self.save_model = save_model
        self.save_model_path = save_model_pth
        # Fail at construction time rather than halfway through a long run.
        self._validate_steps()

    def __len__(self):
        """Return the number of steps in the pipeline."""
        return len(self.steps)

    @staticmethod
    def _as_named_step(step):
        """Return ``step`` as a ``(name, process)`` pair, naming bare objects by class."""
        is_pair = isinstance(step, (tuple, list)) and len(step) == 2
        if is_pair and isinstance(step[0], str):
            return step[0], step[1]
        return type(step).__name__, step

    @staticmethod
    def _format_duration(seconds):
        """Format a duration in seconds as ``<h>hr <m>min <s>s``."""
        # divmod keeps minutes within the hour; a plain ``seconds // 60``
        # would report the total number of minutes instead.
        minutes, secs = divmod(seconds, 60)
        hours, minutes = divmod(minutes, 60)
        return "{}hr {}min {:.2f}s".format(int(hours), int(minutes), secs)

    def _validate_steps(self):
        """
        Check that every step supports train_student, prune or quantize.

        :raises TypeError: if a step implements none of the supported methods
        """
        for _, _, process in self._iter():
            # hasattr only accepts a single attribute name, so test each one.
            if not any(hasattr(process, m) for m in self._SUPPORTED_METHODS):
                raise TypeError("All the steps must support at least one of "
                                "train_student, prune or quantize method, {} is not"
                                " supported yet".format(str(process)))

    def get_steps(self):
        """Return the steps exactly as they were passed to the constructor."""
        return self.steps

    def _iter(self, num_steps=-1):
        """
        Yield ``(index, name, process)`` for the first ``num_steps`` steps.

        :param: num_steps (int) number of leading steps to yield, -1 for all
        """
        _length = len(self.steps) if num_steps == -1 else num_steps

        for idx, step in enumerate(islice(self.steps, 0, _length)):
            name, process = self._as_named_step(step)
            yield idx, name, process

    def _fit(self):
        """Run every step in order, timing each one."""
        pbar = tqdm(total=len(self)) if self.verbose else None
        previous = None

        try:
            for idx, name, process in self._iter():
                print("Starting {}".format(name))
                is_distiller = hasattr(process, 'train_student')

                # Hand the model produced by the previous step to a distiller:
                # distillers expose it as ``student_model``, other steps as ``model``.
                if previous is not None and is_distiller:
                    if hasattr(previous, 'train_student'):
                        process.student_model = previous.student_model
                    else:
                        process.student_model = previous.model

                t1 = time.time()
                if is_distiller:
                    process.train_student(self.epochs, self.plot_losses,
                                          self.save_model, self.save_model_path)
                elif hasattr(process, 'prune'):
                    process.prune()
                elif hasattr(process, 'quantize'):
                    process.quantize()
                else:
                    raise TypeError("{} is not supported by the pipeline yet."
                                    .format(process))

                t2 = time.time() - t1
                print("{} completed in {}".format(name, self._format_duration(t2)))

                if pbar is not None:
                    pbar.update(1)
                previous = process
        finally:
            # Close the progress bar even when a step raises.
            if pbar is not None:
                pbar.close()

    def train(self):
        """
        Train the (student) model sequentially through the list.
        """
        # Re-validate: ``steps`` may have been reassigned since construction.
        self._validate_steps()

        t1 = time.time()
        self._fit()
        t2 = time.time() - t1
        print("Pipeline execution completed in {}".format(self._format_duration(t2)))

setup.py

Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,67 +13,74 @@
1313
LONG_DESCRIPTION = f.read()
1414

1515
# Define the keywords
16-
KEYWORDS = ["Knowledge Distillation", "Pruning", "Quantization", "pytorch", "machine learning", "deep learning"]
16+
KEYWORDS = [
17+
"Knowledge Distillation",
18+
"Pruning",
19+
"Quantization",
20+
"pytorch",
21+
"machine learning",
22+
"deep learning",
23+
]
1724
REQUIRE_PATH = "requirements.txt"
1825
PROJECT = os.path.abspath(os.path.dirname(__file__))
19-
setup_requirements = ['pytest-runner']
26+
setup_requirements = ["pytest-runner"]
2027

21-
test_requirements = ['pytest', 'pytest-cov']
28+
test_requirements = ["pytest", "pytest-cov"]
2229

2330
requirements = [
24-
'pip==19.3.1',
25-
'transformers==4.6.1',
26-
'sacremoses',
27-
'tokenizers==0.10.1',
28-
'huggingface-hub==0.0.8',
29-
'torchtext==0.9.1',
30-
'bumpversion==0.5.3',
31-
'wheel==0.32.1',
32-
'watchdog==0.9.0',
33-
'flake8==3.5.0',
34-
'tox==3.5.2',
35-
'coverage==4.5.1',
36-
'Sphinx==1.8.1',
37-
'twine==1.12.1',
38-
'pytest==3.8.2',
39-
'pytest-runner==4.2',
40-
'pytest-cov==2.6.1',
41-
'matplotlib==3.2.1',
42-
'torch==1.8.1',
43-
'torchvision==0.9.1',
44-
'tensorboard==2.2.1',
45-
'contextlib2==0.6.0.post1',
46-
'pandas==1.0.1',
47-
'tqdm==4.42.1',
48-
'numpy==1.18.1',
49-
'sphinx-rtd-theme==0.5.0',
31+
"pip==19.3.1",
32+
"transformers==4.6.1",
33+
"sacremoses",
34+
"tokenizers==0.10.1",
35+
"huggingface-hub==0.0.8",
36+
"torchtext==0.9.1",
37+
"bumpversion==0.5.3",
38+
"wheel==0.32.1",
39+
"watchdog==0.9.0",
40+
"flake8==3.5.0",
41+
"tox==3.5.2",
42+
"coverage==4.5.1",
43+
"Sphinx==1.8.1",
44+
"twine==1.12.1",
45+
"pytest==3.8.2",
46+
"pytest-runner==4.2",
47+
"pytest-cov==2.6.1",
48+
"matplotlib==3.2.1",
49+
"torch==1.8.1",
50+
"torchvision==0.9.1",
51+
"tensorboard==2.2.1",
52+
"contextlib2==0.6.0.post1",
53+
"pandas==1.0.1",
54+
"tqdm==4.42.1",
55+
"numpy==1.18.1",
56+
"sphinx-rtd-theme==0.5.0",
5057
]
5158

5259

5360
if __name__ == "__main__":
5461
setup(
55-
author="Het Shah",
56-
author_email='divhet163@gmail.com',
57-
classifiers=[
58-
'Development Status :: 2 - Pre-Alpha',
59-
'Intended Audience :: Developers',
60-
'License :: OSI Approved :: MIT License',
61-
'Natural Language :: English',
62-
'Programming Language :: Python :: 3.6',
63-
'Programming Language :: Python :: 3.7',
64-
],
65-
description="A Pytorch Library to help extend all Knowledge Distillation works",
66-
install_requires=requirements,
67-
license="MIT license",
68-
long_description=LONG_DESCRIPTION,
69-
include_package_data=True,
70-
keywords=KEYWORDS,
71-
name='KD_Lib',
72-
packages=find_packages(where=PROJECT),
73-
setup_requires=setup_requirements,
74-
test_suite="tests",
75-
tests_require=test_requirements,
76-
url="https://github.com/SforAiDL/KD_Lib",
77-
version='0.0.29',
78-
zip_safe=False,
79-
)
62+
author="Het Shah",
63+
author_email="divhet163@gmail.com",
64+
classifiers=[
65+
"Development Status :: 2 - Pre-Alpha",
66+
"Intended Audience :: Developers",
67+
"License :: OSI Approved :: MIT License",
68+
"Natural Language :: English",
69+
"Programming Language :: Python :: 3.6",
70+
"Programming Language :: Python :: 3.7",
71+
],
72+
description="A Pytorch Library to help extend all Knowledge Distillation works",
73+
install_requires=requirements,
74+
license="MIT license",
75+
long_description=LONG_DESCRIPTION,
76+
include_package_data=True,
77+
keywords=KEYWORDS,
78+
name="KD_Lib",
79+
packages=find_packages(where=PROJECT),
80+
setup_requires=setup_requirements,
81+
test_suite="tests",
82+
tests_require=test_requirements,
83+
url="https://github.com/SforAiDL/KD_Lib",
84+
version="0.0.29",
85+
zip_safe=False,
86+
)

tests/test_pipeline.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
from torchvision import datasets, transforms

from KD_Lib.KD import VanillaKD
from KD_Lib.models import Shallow
from KD_Lib.Pruning import Lottery_Tickets_Pruner
from KD_Lib.Quantization import Dynamic_Quantizer
from KD_Lib.utils import Pipeline
8+
9+
10+
# MNIST data loaders shared by the tests in this module. Both splits use the
# same preprocessing: convert to a tensor, then normalise with the constants
# passed to Normalize below.
_mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

_train_set = datasets.MNIST(
    "mnist_data", train=True, download=True, transform=_mnist_transform
)
train_loader = torch.utils.data.DataLoader(_train_set, batch_size=32, shuffle=True)

# NOTE(review): the test split is created without download=True; presumably it
# relies on the files fetched for the training split above -- confirm.
_test_set = datasets.MNIST("mnist_data", train=False, transform=_mnist_transform)
test_loader = torch.utils.data.DataLoader(_test_set, batch_size=32, shuffle=True)
34+
35+
36+
def test_Pipeline():
    """Smoke test: run distillation, pruning and quantization through one Pipeline."""
    teacher = Shallow(hidden_size=400)
    student = Shallow(hidden_size=100)

    # Build the optimizers from the models defined above (the previous code
    # referenced the undefined names `teac` / `stud`). `torch.optim` is used
    # because only `torch` itself is imported in this file.
    t_optimizer = torch.optim.SGD(teacher.parameters(), 0.01)
    s_optimizer = torch.optim.SGD(student.parameters(), 0.01)

    distiller = VanillaKD(
        teacher, student, train_loader, test_loader, t_optimizer, s_optimizer
    )

    pruner = Lottery_Tickets_Pruner(student, train_loader, test_loader)

    quantizer = Dynamic_Quantizer(student, test_loader, {torch.nn.Linear})

    # Pipeline unpacks every step as a (name, process) pair, so name each one.
    steps = [
        ("distill", distiller),
        ("prune", pruner),
        ("quantize", quantizer),
    ]
    pipe = Pipeline(steps, 1)
    pipe.train()

0 commit comments

Comments
 (0)