first commit
This commit is contained in:
60
app/detector.py
Normal file
60
app/detector.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import DetrForObjectDetection, DetrImageProcessor
|
||||
|
||||
|
||||
class DetrVehicleDetector:
    """Detect vehicles in RGB frames with a Hugging Face DETR checkpoint.

    The image processor and model are loaded once at construction time, the
    model is moved to the selected device and switched to eval mode.
    ``detect`` then returns only detections whose label name is contained in
    ``vehicle_labels``.
    """

    def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
        """Load the processor and model and place the model on a device.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the detection model.
            confidence: Score threshold handed to the DETR post-processing.
            vehicle_labels: Label names to keep; all other classes are dropped.
        """
        self.confidence = confidence
        self.vehicle_labels = vehicle_labels
        self.device = self._select_device()
        self.processor = DetrImageProcessor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # ``torch.backends.mps`` may be missing on older PyTorch releases, so
        # look it up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @property
    def device_name(self) -> str:
        """Type string of the device in use (e.g. ``"cpu"`` or ``"mps"``)."""
        return self.device.type

    @staticmethod
    def _format_detections(
        results: dict[str, Any],
        id2label: dict[int, str],
        vehicle_labels: set[str],
    ) -> list[dict[str, Any]]:
        """Convert post-processed tensors into plain dicts, keeping vehicles only.

        Args:
            results: Mapping with ``"scores"``, ``"labels"`` and ``"boxes"``
                tensors for a single image.
            id2label: Mapping from class id to label name.
            vehicle_labels: Label names to keep.

        Returns:
            One dict per kept detection with ``"label"``, ``"score"`` (rounded
            to 4 decimals) and ``"box"`` (``[x1, y1, x2, y2]`` as ints,
            truncated toward zero).
        """
        # Move each tensor to the host once instead of syncing per detection.
        scores = results["scores"].detach().cpu().tolist()
        labels = results["labels"].detach().cpu().tolist()
        boxes = results["boxes"].detach().cpu().numpy().astype(int).tolist()

        detections = []
        for score, label_id, box in zip(scores, labels, boxes):
            label_name = id2label[label_id]
            if label_name not in vehicle_labels:
                continue
            x1, y1, x2, y2 = box
            detections.append(
                {
                    "label": label_name,
                    "score": round(float(score), 4),
                    "box": [x1, y1, x2, y2],
                }
            )
        return detections

    @torch.no_grad()
    def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
        """Run DETR on one frame and return the vehicle detections.

        Args:
            frame_rgb: Array accepted by ``PIL.Image.fromarray``; presumably
                an RGB image array -- confirm against the caller.

        Returns:
            A list of dicts with ``"label"``, ``"score"`` and ``"box"`` keys,
            one per detection above the confidence threshold whose label is
            in ``self.vehicle_labels``.
        """
        image = Image.fromarray(frame_rgb)
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        outputs = self.model(**inputs)
        # DETR post-processing needs the original image size; PIL's size is
        # (width, height), so reverse it into (height, width).
        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.confidence,
        )[0]

        return self._format_detections(
            results, self.model.config.id2label, self.vehicle_labels
        )
|
||||
Reference in New Issue
Block a user