Skip to content

Commit 4ee788a

Browse files
feat(model): revert inference code on YOLO8
1 parent bcf5216 commit 4ee788a

File tree

4 files changed

+13
-16
lines changed

4 files changed

+13
-16
lines changed

connection/ModelPredictionsReceiver.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,10 @@ def predict(self, img: np.array) -> np.array:
4141
img[img.shape[0] // 2:, :, :]
4242
]
4343
preds = [self._predict(imgs[0]), self._predict(imgs[1])]
44-
print(preds)
4544
if len(preds[1]):
4645
preds[1][:, 1] += img.shape[0] // 2
4746
preds[1][:, 3] += img.shape[0] // 2
4847
result = np.append(*preds)
4948
if len(result.shape) == 1:
5049
result = np.expand_dims(result, 0)
51-
print(result.shape)
52-
print(result)
5350
return result

idle_model/IdleObjectDetectionModel.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,24 @@
22
import numpy as np
33
from transformers import DetrImageProcessor, DetrForObjectDetection
44
import torch
5+
from ultralytics import YOLO
56

67

78
class IdleObjectDetectionModel:
89
def __init__(self, path: str, conf_thresh, iou_thresh, classes) -> None:
9-
self.model = DetrForObjectDetection.from_pretrained(path)
10-
self.model.eval()
11-
self.processor = DetrImageProcessor.from_pretrained(path)
10+
self.model = YOLO(path)
1211
self.conf_thresh = conf_thresh
1312
self.iou_thresh = iou_thresh
1413
self.classes = classes
1514

1615
@torch.no_grad()
1716
def __call__(self, img: np.array) -> np.array:
18-
inputs = self.processor(images=img, return_tensors="pt")
19-
outputs = self.model(**inputs)
20-
target_sizes = torch.tensor([img.shape[:-1]])
21-
results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=self.conf_thresh)[0]
22-
pred = []
23-
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
24-
if self.model.config.id2label[label.item()] == 'cell phone':
25-
pred.append([round(i, 2) for i in box.tolist()] + [score.item()])
26-
return np.array(pred)
17+
output = self.model(
18+
source=img,
19+
conf=self.conf_thresh,
20+
iou=self.iou_thresh,
21+
max_det=600,
22+
classes=self.classes,
23+
verbose=False
24+
)[0].boxes
25+
return output.xyxy

idle_model/configs/confs.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
],
55
"iou_thres": 0.5,
66
"conf_thres": 0.8,
7-
"model_path": "./weights",
7+
"model_path": "weights/yolov8l.pt",
88
"port": 5001
99
}

main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,4 @@
4343
img = utils.put_rectangle(img, preds[:, :4], preds[:, 4])
4444
reporter.send_report(reporter.create_report(img, str(start_tracking)))
4545
prev_preds = preds
46+
time.sleep(2)

0 commit comments

Comments
 (0)