Commit b6e80be

feat(idle_model): write code for getting inference from yolor model
1 parent 586eb04 commit b6e80be

File tree

5 files changed: +118 -15 lines

Lines changed: 24 additions & 12 deletions
@@ -1,22 +1,34 @@
 import torch
-from ultralytics import YOLO
+from yolor.model import get_model
+import numpy as np
+from yolor.utils.datasets import letterbox
+from yolor.utils.general import non_max_suppression, scale_coords
 
 
 class IdleObjectDetectionModel:
     def __init__(self, path: str, conf_thresh, iou_thresh, classes) -> None:
-        self.model = YOLO(path)
+        self.model, self.device = get_model(path, "yolor/yolor_csp_x.cfg")
         self.conf_thresh = conf_thresh
         self.iou_thresh = iou_thresh
         self.classes = classes
 
+    def __preprocess_image__(self, img: np.array) -> np.array:
+        self.img_shape = img.shape
+        img = letterbox(img.copy(), new_shape=1280, auto_size=64)[0]
+        img = img[:, :, ::-1].transpose(2, 0, 1)
+        img = np.ascontiguousarray(img)
+        img = torch.from_numpy(img).to(self.device)
+        img = img.float()
+        img /= 255.0
+        img = img.unsqueeze(0)
+        return img
+
     @torch.no_grad()
-    def __call__(self, img) -> list:
-        results = self.model(
-            source=img,
-            conf=self.conf_thresh,
-            iou=self.iou_thresh,
-            max_det=600,
-            classes=self.classes,
-            verbose=False
-        )[0].boxes
-        return results.xyxy, results.conf
+    def __call__(self, img: np.array) -> list:
+        img = self.__preprocess_image__(img)
+        pred = self.model(img, augment=False)[0]
+        pred = non_max_suppression(
+            pred, 0.45, 0.5, classes=[67], agnostic=False)[0]
+        pred[:, :4] = scale_coords(
+            img.shape[2:], pred[:, :4], self.img_shape).round()
+        return pred[:, :4], pred[:, 4]
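For orientation, a minimal usage sketch of the rewritten class follows. The import path, weight file, and input image are assumptions and not part of the commit. Note that the new __call__ hard-codes the NMS thresholds (0.45, 0.5) and classes=[67] rather than using the conf_thresh, iou_thresh, and classes values stored in __init__.

# from idle_models.IdleObjectDetectionModel import IdleObjectDetectionModel  # hypothetical import path
import cv2

model = IdleObjectDetectionModel(
    path="weights/yolor_csp_x.pt",   # assumed checkpoint location
    conf_thresh=0.45,                # note: __call__ currently hard-codes 0.45 / 0.5 for NMS
    iou_thresh=0.5,
    classes=[67],                    # note: __call__ currently hard-codes classes=[67]
)

img = cv2.imread("frame.jpg")        # BGR frame; preprocessing flips to RGB and letterboxes to 1280
boxes, scores = model(img)           # xyxy boxes rescaled to the original frame, plus confidence scores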

idle_models/requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -7,5 +7,4 @@ pydantic==1.10.2
 python-dotenv==1.0.0
 PyYAML==6.0
 requests==2.27.1
-ultralytics==8.0.112
 Flask==2.2.2

idle_models/yolor/coco.names

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+dining table
+toilet
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
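The classes=[67] filter hard-coded in the NMS call above corresponds to the 68th entry of this list, "cell phone", since the indices are 0-based. A small sketch for checking that mapping, assuming the file is read from the path added by this commit:

# Map COCO class names to indices and look up the one used in non_max_suppression(classes=[67])
with open("idle_models/yolor/coco.names") as f:
    names = [line.strip() for line in f if line.strip()]

print(names[67])                  # -> "cell phone"
print(names.index("cell phone"))  # -> 67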

idle_models/yolor/model.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import torch
+from yolor.utils.torch_utils import select_device
+from yolor.models.models import Darknet
+
+
+def get_model(weights, cfg):
+    imgsz = 1280
+    device = select_device('cpu')
+    model = Darknet(cfg, imgsz)
+    model.load_state_dict(torch.load(weights, map_location=device)['model'])
+    model.to(device).eval()
+    return model, device
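A sketch of calling the new helper directly; the weight path below is an assumption. The loader expects a checkpoint dictionary with a 'model' key holding a state_dict, and the device is fixed to CPU because select_device('cpu') is hard-coded.

from yolor.model import get_model

# Build the Darknet model from the YOLOR cfg and load the assumed checkpoint
model, device = get_model("weights/yolor_csp_x.pt", "yolor/yolor_csp_x.cfg")
print(next(model.parameters()).device)  # -> cpu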

send_request.py

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@ def predict(img: np.array, server_url: str, logger: Logger):
         confidences = np.array(response.json().get("confidences"))
     else:
         logger.warning(
-            "Status code = {}\n JSON = {}".format(
-                response.status_code, response.json())
+            "Status code = {}\n response = {}".format(
+                response.status_code, response)
         )
         coordinates = confidences = None
     return [coordinates, confidences]
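For context, a sketch of invoking the client-side helper whose logging was changed above. The endpoint URL, logger setup, and input image are assumptions; only lines 18-25 of predict() appear in this diff.

import logging

import cv2
from send_request import predict  # assumed import of the module at the repository root

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("send_request")

img = cv2.imread("frame.jpg")  # assumed input frame
coordinates, confidences = predict(img, "http://localhost:5000/predict", logger)  # assumed endpoint URL
if coordinates is None:
    # A non-200 response now logs the raw response object instead of calling response.json()
    logger.info("Prediction failed; see the warning emitted by predict().")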
