|
1 | 1 | import torch
|
2 |
| -from yolor.model import get_model |
3 | 2 | import numpy as np
|
4 |
| -from yolor.utils.datasets import letterbox |
5 |
| -from yolor.utils.general import non_max_suppression, scale_coords |
| 3 | +from transformers import DetrImageProcessor, DetrForObjectDetection, DetrConfig |
| 4 | +import torch |
6 | 5 |
|
7 | 6 |
|
8 | 7 | class IdleObjectDetectionModel:
|
9 | 8 | def __init__(self, path: str, conf_thresh, iou_thresh, classes) -> None:
|
10 |
| - self.model, self.device = get_model(path, "yolor/yolor_csp_x.cfg") |
| 9 | + self.model = DetrForObjectDetection.from_pretrained(path) |
| 10 | + self.model.eval() |
| 11 | + self.processor = DetrImageProcessor.from_pretrained(path) |
11 | 12 | self.conf_thresh = conf_thresh
|
12 | 13 | self.iou_thresh = iou_thresh
|
13 | 14 | self.classes = classes
|
14 | 15 |
|
15 |
| - def __preprocess_image__(self, img: np.array) -> np.array: |
16 |
| - self.img_shape = img.shape |
17 |
| - img = letterbox(img.copy(), new_shape=1280, auto_size=64)[0] |
18 |
| - img = img[:, :, ::-1].transpose(2, 0, 1) |
19 |
| - img = np.ascontiguousarray(img) |
20 |
| - img = torch.from_numpy(img).to(self.device) |
21 |
| - img = img.float() |
22 |
| - img /= 255.0 |
23 |
| - img = img.unsqueeze(0) |
24 |
| - return img |
25 |
| - |
26 | 16 | @torch.no_grad()
|
27 |
| - def __call__(self, img: np.array) -> torch.Tensor: |
28 |
| - img = self.__preprocess_image__(img) |
29 |
| - pred = self.model(img, augment=False)[0] |
30 |
| - pred = non_max_suppression( |
31 |
| - pred, 0.45, 0.5, classes=[67], agnostic=False)[0] |
32 |
| - pred[:, :4] = scale_coords( |
33 |
| - img.shape[2:], pred[:, :4], self.img_shape).round() |
34 |
| - return pred[:, :5] |
| 17 | + def __call__(self, img: np.array) -> np.array: |
| 18 | + inputs = self.processor(images=img, return_tensors="pt") |
| 19 | + outputs = self.model(**inputs) |
| 20 | + target_sizes = torch.tensor([img.shape[:-1]]) |
| 21 | + results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=self.conf_thresh)[0] |
| 22 | + pred = [] |
| 23 | + for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): |
| 24 | + if self.model.config.id2label[label.item()] == 'cell phone': |
| 25 | + pred.append([round(i, 2) for i in box.tolist()] + [score.item()]) |
| 26 | + return np.array(pred) |
0 commit comments