Skip to content

Commit 0f23c98

Browse files
committed
added onnx export (but check failing on models) + load_state_dict fixed for old checkpoints
1 parent 48280da commit 0f23c98

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

main.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import os
1313
import argparse
14+
import onnxruntime as ort
15+
import numpy as np
1416

1517
from nbdt.utils import (
1618
progress_bar, generate_fname, generate_kwargs, Colors, maybe_install_wordnet
@@ -39,6 +41,7 @@
3941
parser.add_argument('--pretrained', action='store_true',
4042
help='Download pretrained model. Not all models support this.')
4143
parser.add_argument('--eval', help='eval only', action='store_true')
44+
parser.add_argument('--onnx', help='export only', action='store_true')
4245

4346
# options specific to this project and its dataloaders
4447
parser.add_argument('--loss', choices=loss.names, default='CrossEntropyLoss')
@@ -141,6 +144,19 @@
141144
checkpoint_path = './checkpoint/{}.pth'.format(checkpoint_fname)
142145
print(f'==> Checkpoints will be saved to: {checkpoint_path}')
143146

147+
148+
# TODO(alvin): fix checkpoint structure so that this isn't neededd
149+
def load_state_dict(state_dict):
150+
try:
151+
net.load_state_dict(state_dict)
152+
except RuntimeError as e:
153+
if 'Missing key(s) in state_dict:' in str(e):
154+
net.load_state_dict({
155+
key.replace('module.', '', 1): value
156+
for key, value in state_dict.items()
157+
})
158+
159+
144160
resume_path = args.path_resume or checkpoint_path
145161
if args.resume:
146162
# Load checkpoint.
@@ -149,18 +165,19 @@
149165
if not os.path.exists(resume_path):
150166
print('==> No checkpoint found. Skipping...')
151167
else:
152-
checkpoint = torch.load(resume_path)
168+
checkpoint = torch.load(resume_path, map_location=torch.device(device))
153169

154170
if 'net' in checkpoint:
155-
net.load_state_dict(checkpoint['net'])
171+
load_state_dict(checkpoint['net'])
156172
best_acc = checkpoint['acc']
157173
start_epoch = checkpoint['epoch']
158174
Colors.cyan(f'==> Checkpoint found for epoch {start_epoch} with accuracy '
159175
f'{best_acc} at {resume_path}')
160176
else:
161-
net.load_state_dict(checkpoint)
177+
load_state_dict(checkpoint)
162178
Colors.cyan(f'==> Checkpoint found at {resume_path}')
163179

180+
164181
criterion = nn.CrossEntropyLoss()
165182
class_criterion = getattr(loss, args.loss)
166183
loss_kwargs = generate_kwargs(args, class_criterion,
@@ -270,6 +287,34 @@ def test(epoch, analyzer, checkpoint=True):
270287
analyzer = class_analysis(**analyzer_kwargs)
271288

272289

290+
if args.onnx:
291+
if not args.resume and not args.pretrained:
292+
Colors.red(' * Warning: Model is not loaded from checkpoint. '
293+
'Use --resume or --pretrained (if supported)')
294+
295+
fname = f"out/{checkpoint_fname}.onnx"
296+
dummy_input = torch.randn(1, 3, 32, 32)
297+
torch.onnx.export(
298+
net, dummy_input, fname,
299+
input_names=["x"], output_names=["outputs"])
300+
print(f"=> Wrote ONNX export to {fname}")
301+
302+
outputs_torch = net(dummy_input)
303+
outputs_torch = outputs_torch.detach().numpy()
304+
305+
ort_session = ort.InferenceSession(fname)
306+
outputs_onnx = ort_session.run(None, {
307+
'x': dummy_input.numpy()
308+
})
309+
310+
if (outputs_torch == outputs_onnx).all():
311+
Colors.green("=> ONNX export check succeeded: Outputs match.")
312+
else:
313+
Colors.red("=> ONNX export check failed: Outputs do not match.")
314+
315+
exit()
316+
317+
273318
if args.eval:
274319
if not args.resume and not args.pretrained:
275320
Colors.red(' * Warning: Model is not loaded from checkpoint. '

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ torchvision
44
nltk
55
scikit-learn
66
networkx
7+
onnxruntime

0 commit comments

Comments
 (0)