Skip to content

Commit c8926e0

Browse files
committed
siompler sample script
1 parent 71555e3 commit c8926e0

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

nbdt/bin/nbdt

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,12 @@
33

44
from nbdt.model import SoftNBDT
55
from pytorchcv.models.wrn_cifar import wrn28_10_cifar10
6-
from PIL import Image
7-
from urllib.request import urlopen, Request
86
from torchvision import transforms
9-
import io
7+
from nbdt.utils import DATASET_TO_CLASSES
108
import sys
119

1210
assert len(sys.argv) > 1, "Need to pass image URL or image path as argument"
1311

14-
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
15-
1612
# load pretrained NBDT
1713
model = wrn28_10_cifar10()
1814
model = SoftNBDT(
@@ -22,27 +18,17 @@ model = SoftNBDT(
2218
pretrained=True,
2319
arch='wrn28_10_cifar10')
2420

25-
# load image
26-
path = sys.argv[1]
27-
headers = {
28-
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3'
29-
}
30-
if 'http' in path:
31-
request = Request(path, headers=headers)
32-
file = io.BytesIO(urlopen(request).read())
33-
else:
34-
file = path
35-
im = Image.open(file)
36-
37-
# transform image
21+
# load + transform image
22+
im = load_image_from_path(sys.argv[1])
3823
transforms = transforms.Compose([
3924
transforms.Resize(32),
4025
transforms.CenterCrop(32),
41-
transforms.ToTensor()
26+
transforms.ToTensor(),
27+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
4228
])
4329
x = transforms(im)[None]
4430

4531
# run inference
4632
outputs = model(x)
47-
cls = classes[outputs[0]]
33+
cls = DATASET_TO_CLASSES['CIFAR10'][outputs[0]]
4834
print(cls)

nbdt/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
import math
1010
import numpy as np
1111

12+
from urllib.request import urlopen, Request
13+
from PIL import Image
1214
import torch.nn as nn
1315
import torch.nn.init as init
1416
from pathlib import Path
17+
import io
1518

1619
# tree-generation consntants
1720
METHODS = ('wordnet', 'random', 'induced')
@@ -22,6 +25,12 @@
2225
'TinyImagenet200': 200,
2326
'Imagenet1000': 1000
2427
}
28+
DATASET_TO_CLASSES = {
29+
'CIFAR10': [
30+
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
31+
'horse', 'ship', 'truck'
32+
]
33+
}
2534

2635

2736
def fwd():
@@ -61,6 +70,19 @@ def populate_kwargs(args, kwargs, object, name='Dataset', keys=(), globals={}):
6170
f'{key}: {value}')
6271

6372

73+
def load_image_from_path(path):
74+
"""Path can be local or a URL"""
75+
headers = {
76+
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.3'
77+
}
78+
if 'http' in path:
79+
request = Request(path, headers=headers)
80+
file = io.BytesIO(urlopen(request).read())
81+
else:
82+
file = path
83+
return Image.open(file)
84+
85+
6486
class Colors:
6587
RED = '\x1b[31m'
6688
GREEN = '\x1b[32m'

0 commit comments

Comments
 (0)