63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
"""DETR 车辆检测模块:加载模型并输出车辆类别、置信度和检测框。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Any
|
||
|
||
import torch
|
||
from PIL import Image
|
||
from transformers import DetrForObjectDetection, DetrImageProcessor
|
||
|
||
|
||
class DetrVehicleDetector:
    """Detect vehicles in RGB frames with a Hugging Face DETR model.

    Runs the DETR object detector on a frame, then keeps only detections
    whose class name is in ``vehicle_labels`` and whose score passes
    ``confidence``.
    """

    def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
        """Load the image processor and model, and move the model to a device.

        Args:
            model_name: Model id or path passed to ``from_pretrained`` for both
                the ``DetrImageProcessor`` and the ``DetrForObjectDetection``.
            confidence: Minimum score a detection needs to be kept; forwarded
                as ``threshold`` to ``post_process_object_detection``.
            vehicle_labels: Class names (as they appear in the model's
                ``config.id2label``) that count as vehicles.
        """
        self.confidence = confidence
        self.vehicle_labels = vehicle_labels
        self.device = self._select_device()
        self.processor = DetrImageProcessor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)
        self.model.to(self.device)
        # Inference only: switch layers such as dropout to eval behavior.
        self.model.eval()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no ``torch.backends.mps``; treat that as
        # "MPS unavailable" instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _clamp_box(box: list[float], width: int, height: int) -> list[int]:
        """Clamp an ``[x1, y1, x2, y2]`` box to the image and truncate to ints.

        Post-processed DETR boxes are not guaranteed to lie inside the image.
        A negative coordinate would misbehave if used as a numpy slice index,
        because negative indices wrap around. Truncation via ``int()`` matches
        the previous ``astype(int)`` behavior for in-range values.
        """
        x1, y1, x2, y2 = box
        return [
            int(min(max(x1, 0.0), width)),
            int(min(max(y1, 0.0), height)),
            int(min(max(x2, 0.0), width)),
            int(min(max(y2, 0.0), height)),
        ]

    @property
    def device_name(self) -> str:
        """Device type the model runs on, e.g. ``"cpu"``, ``"mps"`` or ``"cuda"``."""
        return self.device.type

    @torch.no_grad()
    def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
        """Run DETR on one RGB frame and return the vehicle detections.

        Args:
            frame_rgb: Array accepted by ``PIL.Image.fromarray``, in RGB
                channel order (presumably an ``HxWx3`` uint8 array; confirm
                with callers).

        Returns:
            One dict per kept detection, with these keys:
            ``"label"`` is the class name;
            ``"score"`` is the confidence rounded to 4 decimals;
            ``"box"`` is the ``[x1, y1, x2, y2]`` integer pixel corners,
            clamped to the image bounds.
        """
        image = Image.fromarray(frame_rgb)
        width, height = image.size  # PIL reports (width, height)

        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        outputs = self.model(**inputs)

        # DETR post-processing rescales boxes to the original image size and
        # expects target sizes as (height, width).
        target_sizes = torch.tensor([[height, width]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.confidence,
        )[0]

        # Copy results to the host once, instead of forcing a device sync for
        # every single detection inside the loop.
        scores = results["scores"].cpu().tolist()
        label_ids = results["labels"].cpu().tolist()
        boxes = results["boxes"].cpu().tolist()
        id2label = self.model.config.id2label

        detections = []
        for score, label_id, box in zip(scores, label_ids, boxes):
            label_name = id2label[label_id]
            if label_name not in self.vehicle_labels:
                continue
            detections.append(
                {
                    "label": label_name,
                    "score": round(float(score), 4),
                    "box": self._clamp_box(box, width, height),
                }
            )
        return detections
|