Skip to content

Commit 38b2812

Browse files
committed
example app
1 parent 0f23c98 commit 38b2812

File tree

5 files changed

+58
-2
lines changed

5 files changed

+58
-2
lines changed

examples/app/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
super simple flask app for serving NBDT predictions
2+
3+
deployed to repl.it (link tbd)

examples/app/app.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from flask import Flask
2+
from nbdt.model import SoftNBDT
3+
from nbdt.models import wrn28_10_cifar10
4+
from torchvision import transforms
5+
from nbdt.utils import DATASET_TO_CLASSES, load_image_from_path, maybe_install_wordnet
6+
7+
app = Flask(__name__)
8+
9+
10+
@app.route('/')
11+
def home():
12+
# load pretrained NBDT
13+
model = wrn28_10_cifar10()
14+
model = SoftNBDT(
15+
pretrained=True,
16+
dataset='CIFAR10',
17+
arch='wrn28_10_cifar10',
18+
hierarchy='wordnet',
19+
model=model)
20+
21+
# load + transform image
22+
im = load_image_from_path("https://images.pexels.com/photos/1170986/pexels-photo-1170986.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32")
23+
transform = transforms.Compose([
24+
transforms.Resize(32),
25+
transforms.CenterCrop(32),
26+
transforms.ToTensor(),
27+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
28+
])
29+
x = transform(im)[None]
30+
31+
# run inference
32+
outputs, decisions = model.forward_with_decisions(x) # use `model(x)` to obtain just logits
33+
_, predicted = outputs.max(1)
34+
return {
35+
'predicted': [DATASET_TO_CLASSES['CIFAR10'][pred] for pred in predicted],
36+
'decisions': [[info['name'] for info in decision] for decision in decisions]
37+
}
38+
39+
40+
if __name__ == '__main__':
41+
app.run(host='0.0.0.0', port=8080)

examples/app/pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[tool]
2+
[tool.poetry]
3+
name = "nbdt-demo"
4+
version = "0.0.1"
5+
description = ""
6+
authors = ["Alvin Wan <hi@alvinwan.com>"]
7+
[tool.poetry.dependencies]
8+
python = "^3.7"
9+
nbdt = "^0.0.1"
10+
flask = "^1.1.1"

examples/app/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
nbdt==0.0.1
2+
flask==1.1.1

nbdt/bin/nbdt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ model = SoftNBDT(
2222

2323
# load + transform image
2424
im = load_image_from_path(sys.argv[1])
25-
transforms = transforms.Compose([
25+
transform = transforms.Compose([
2626
transforms.Resize(32),
2727
transforms.CenterCrop(32),
2828
transforms.ToTensor(),
2929
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
3030
])
31-
x = transforms(im)[None]
31+
x = transform(im)[None]
3232

3333
# run inference
3434
outputs, decisions = model.forward_with_decisions(x) # use `model(x)` to obtain just logits

0 commit comments

Comments
 (0)