|
11 | 11 |
|
12 | 12 | import os
|
13 | 13 | import argparse
|
| 14 | +import onnxruntime as ort |
| 15 | +import numpy as np |
14 | 16 |
|
15 | 17 | from nbdt.utils import (
|
16 | 18 | progress_bar, generate_fname, generate_kwargs, Colors, maybe_install_wordnet
|
|
39 | 41 | parser.add_argument('--pretrained', action='store_true',
|
40 | 42 | help='Download pretrained model. Not all models support this.')
|
41 | 43 | parser.add_argument('--eval', help='eval only', action='store_true')
|
| 44 | +parser.add_argument('--onnx', help='export only', action='store_true') |
42 | 45 |
|
43 | 46 | # options specific to this project and its dataloaders
|
44 | 47 | parser.add_argument('--loss', choices=loss.names, default='CrossEntropyLoss')
|
|
141 | 144 | checkpoint_path = './checkpoint/{}.pth'.format(checkpoint_fname)
|
142 | 145 | print(f'==> Checkpoints will be saved to: {checkpoint_path}')
|
143 | 146 |
|
| 147 | + |
# TODO(alvin): fix checkpoint structure so that this isn't needed
def load_state_dict(state_dict):
    """Load ``state_dict`` into the module-global ``net``.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model prefix
    every parameter key with ``module.``; if a plain load fails with a
    missing-key error, retry with that prefix stripped.

    Raises:
        RuntimeError: re-raised unchanged when the failure is not the
            known ``module.``-prefix mismatch (e.g. a genuine
            architecture/shape mismatch) — previously this was
            silently swallowed, leaving ``net`` unloaded.
    """
    try:
        net.load_state_dict(state_dict)
    except RuntimeError as e:
        if 'Missing key(s) in state_dict:' in str(e):
            # Strip only the first 'module.' occurrence per key.
            net.load_state_dict({
                key.replace('module.', '', 1): value
                for key, value in state_dict.items()
            })
        else:
            # Unrelated load failure: do not swallow it.
            raise
| 159 | + |
144 | 160 | resume_path = args.path_resume or checkpoint_path
|
145 | 161 | if args.resume:
|
146 | 162 | # Load checkpoint.
|
|
149 | 165 | if not os.path.exists(resume_path):
|
150 | 166 | print('==> No checkpoint found. Skipping...')
|
151 | 167 | else:
|
152 |
| - checkpoint = torch.load(resume_path) |
| 168 | + checkpoint = torch.load(resume_path, map_location=torch.device(device)) |
153 | 169 |
|
154 | 170 | if 'net' in checkpoint:
|
155 |
| - net.load_state_dict(checkpoint['net']) |
| 171 | + load_state_dict(checkpoint['net']) |
156 | 172 | best_acc = checkpoint['acc']
|
157 | 173 | start_epoch = checkpoint['epoch']
|
158 | 174 | Colors.cyan(f'==> Checkpoint found for epoch {start_epoch} with accuracy '
|
159 | 175 | f'{best_acc} at {resume_path}')
|
160 | 176 | else:
|
161 |
| - net.load_state_dict(checkpoint) |
| 177 | + load_state_dict(checkpoint) |
162 | 178 | Colors.cyan(f'==> Checkpoint found at {resume_path}')
|
163 | 179 |
|
| 180 | + |
164 | 181 | criterion = nn.CrossEntropyLoss()
|
165 | 182 | class_criterion = getattr(loss, args.loss)
|
166 | 183 | loss_kwargs = generate_kwargs(args, class_criterion,
|
@@ -270,6 +287,34 @@ def test(epoch, analyzer, checkpoint=True):
|
270 | 287 | analyzer = class_analysis(**analyzer_kwargs)
|
271 | 288 |
|
272 | 289 |
|
if args.onnx:
    # Export-only mode: write the model to ONNX, sanity-check the export
    # against a native PyTorch forward pass, then exit.
    if not args.resume and not args.pretrained:
        Colors.red(' * Warning: Model is not loaded from checkpoint. '
                   'Use --resume or --pretrained (if supported)')

    fname = f"out/{checkpoint_fname}.onnx"
    # torch.onnx.export does not create missing directories.
    os.makedirs(os.path.dirname(fname), exist_ok=True)

    # Inference mode so dropout/batchnorm are deterministic; otherwise
    # the PyTorch/ONNX parity check below is meaningless.
    net.eval()
    # NOTE(review): assumes 3x32x32 inputs (CIFAR-style) — confirm for
    # other datasets.
    dummy_input = torch.randn(1, 3, 32, 32)
    torch.onnx.export(
        net, dummy_input, fname,
        input_names=["x"], output_names=["outputs"])
    print(f"=> Wrote ONNX export to {fname}")

    with torch.no_grad():
        outputs_torch = net(dummy_input).cpu().numpy()

    ort_session = ort.InferenceSession(fname)
    # run() returns a list of output arrays; take the first (and only) one.
    outputs_onnx = ort_session.run(None, {
        'x': dummy_input.numpy()
    })[0]

    # Float kernels differ slightly between backends, so exact equality
    # (==) is too strict and flakes; compare within tolerance instead.
    if np.allclose(outputs_torch, outputs_onnx, rtol=1e-4, atol=1e-5):
        Colors.green("=> ONNX export check succeeded: Outputs match.")
    else:
        Colors.red("=> ONNX export check failed: Outputs do not match.")

    exit()
| 317 | + |
273 | 318 | if args.eval:
|
274 | 319 | if not args.resume and not args.pretrained:
|
275 | 320 | Colors.red(' * Warning: Model is not loaded from checkpoint. '
|
|
0 commit comments