Skip to content

Commit 6369053

Browse files
committed
fixed soft inference + nbdt script
1 parent 3810d21 commit 6369053

File tree

3 files changed

+119
-4
lines changed

3 files changed

+119
-4
lines changed

examples/load_pretrained_nbdts.ipynb

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"%pip install nbdt"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"from nbdt.model import SoftNBDT\n",
19+
"from nbdt.models import ResNet18, wrn28_10_cifar10, wrn28_10_cifar100, wrn28_10 # use wrn28_10 for TinyImagenet200\n",
20+
"from torchvision import transforms\n",
21+
"from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path, maybe_install_wordnet\n",
22+
"from IPython.display import display"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"model = wrn28_10_cifar10()\n",
32+
"model = SoftNBDT(\n",
33+
" pretrained=True,\n",
34+
" dataset='CIFAR10',\n",
35+
" arch='wrn28_10_cifar10',\n",
36+
" model=model)"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": null,
42+
"metadata": {},
43+
"outputs": [],
44+
"source": [
45+
"image_urls = {\n",
46+
" 'cat': 'https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300',\n",
47+
" 'bear': 'https://images.pexels.com/photos/1466592/pexels-photo-1466592.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300',\n",
48+
" 'dog': 'https://images.pexels.com/photos/1490908/pexels-photo-1490908.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=300'\n",
49+
"}"
50+
]
51+
},
52+
{
53+
"cell_type": "code",
54+
"execution_count": null,
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"# show image\n",
59+
"im = load_image_from_path(image_urls['cat'])\n",
60+
"display(im)"
61+
]
62+
},
63+
{
64+
"cell_type": "code",
65+
"execution_count": null,
66+
"metadata": {},
67+
"outputs": [],
68+
"source": [
69+
"# load + transform image\n",
70+
"transforms = transforms.Compose([\n",
71+
" transforms.Resize(32),\n",
72+
" transforms.CenterCrop(32),\n",
73+
" transforms.ToTensor(),\n",
74+
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
75+
"])\n",
76+
"x = transforms(im)[None]"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"# run inference\n",
86+
"outputs = model(x) # to get intermediate decisions, use `model.forward_with_decisions(x)` and add `hierarchy='wordnet' to SoftNBDT\n",
87+
"_, predicted = outputs.max(1)\n",
88+
"cls = DATASET_TO_CLASSES['CIFAR10'][predicted[0]]\n",
89+
"print(cls)"
90+
]
91+
}
92+
],
93+
"metadata": {
94+
"kernelspec": {
95+
"display_name": "pytorch-1.2",
96+
"language": "python",
97+
"name": "pytorch-1.2"
98+
},
99+
"language_info": {
100+
"codemirror_mode": {
101+
"name": "ipython",
102+
"version": 3
103+
},
104+
"file_extension": ".py",
105+
"mimetype": "text/x-python",
106+
"name": "python",
107+
"nbconvert_exporter": "python",
108+
"pygments_lexer": "ipython3",
109+
"version": "3.7.4"
110+
}
111+
},
112+
"nbformat": 4,
113+
"nbformat_minor": 4
114+
}

nbdt/analysis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ class SoftEmbeddedDecisionRules(HardEmbeddedDecisionRules):
242242
name = 'NBDT-Soft'
243243

244244
def forward_with_decisions(self, outputs):
245-
predicted = self.forward(outputs)
245+
outputs = self.forward(outputs)
246+
_, predicted = outputs.max(1)
246247

247248
decisions = []
248249
node = self.nodes[0]
@@ -251,7 +252,7 @@ def forward_with_decisions(self, outputs):
251252
leaf = node.wnids[prediction]
252253
decision = leaf_to_path_nodes[leaf]
253254
decisions.append(decision)
254-
return predicted, decisions
255+
return outputs, decisions
255256

256257
def forward(self, outputs):
257258
outputs = SoftTreeSupLoss.inference(

nbdt/bin/nbdt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ x = transforms(im)[None]
3232

3333
# run inference
3434
outputs, decisions = model.forward_with_decisions(x) # use `model(x)` to obtain just logits
35-
outputs = [3]
36-
cls = DATASET_TO_CLASSES['CIFAR10'][outputs[0]]
35+
_, predicted = outputs.max(1)
36+
cls = DATASET_TO_CLASSES['CIFAR10'][predicted[0]]
3737
print('Prediction:', cls, '// Decisions:', ', '.join([
3838
info['name'] for info in decisions[0]
3939
][1:])) # [1:] to skip the root

0 commit comments

Comments
 (0)