"""DETR vehicle detection module: load a DETR model and report vehicle class, confidence and box."""

from __future__ import annotations

from typing import Any

import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor


class DetrVehicleDetector:
    """Run a Hugging Face DETR object detector and keep only vehicle classes.

    Args:
        model_name: Model identifier or local path passed to ``from_pretrained``
            for both the image processor and the detection model.
        confidence: Score threshold forwarded to
            ``post_process_object_detection``; lower-scoring detections are dropped.
        vehicle_labels: Label names (as they appear in ``model.config.id2label``)
            that count as vehicles; detections of any other class are discarded.
    """

    def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
        self.confidence = confidence
        self.vehicle_labels = vehicle_labels
        self.device = self._select_device()
        self.processor = DetrImageProcessor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)
        self.model.to(self.device)
        # Inference only: put layers such as dropout into evaluation mode.
        self.model.eval()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the preferred device: Apple MPS, then CUDA, otherwise CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @property
    def device_name(self) -> str:
        """Type string of the device the model runs on (e.g. ``"mps"``, ``"cpu"``)."""
        return self.device.type

    @torch.no_grad()
    def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
        """Detect vehicles in a single frame.

        Args:
            frame_rgb: Image array accepted by ``PIL.Image.fromarray`` —
                presumably an (H, W, 3) uint8 RGB frame; confirm with callers.
                Non-RGB modes (grayscale, RGBA) are converted to RGB.

        Returns:
            One dict per kept detection with keys ``"label"`` (class name),
            ``"score"`` (confidence rounded to 4 decimals) and ``"box"``
            (``[x1, y1, x2, y2]`` integer pixel coordinates in the original
            image, clamped to the image bounds).
        """
        image = Image.fromarray(frame_rgb)
        # DETR normalizes with 3-channel statistics; grayscale/RGBA frames must be RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        width, height = image.size

        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        outputs = self.model(**inputs)

        # Post-processing rescales boxes to the original image and expects
        # (height, width) per image, whereas PIL's size is (width, height).
        target_sizes = torch.tensor([(height, width)], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.confidence,
        )[0]

        # Copy results to the host once instead of syncing per detection.
        scores = results["scores"].cpu().tolist()
        label_ids = results["labels"].cpu().tolist()
        boxes = results["boxes"].cpu().tolist()
        id2label = self.model.config.id2label

        detections: list[dict[str, Any]] = []
        for score, label_id, (x1, y1, x2, y2) in zip(scores, label_ids, boxes):
            label_name = id2label[label_id]
            if label_name not in self.vehicle_labels:
                continue
            # Predicted boxes can spill slightly past the frame; clamp so
            # negative coordinates never reach callers (they would wrap when
            # used as array slice indices). int() truncates like astype(int).
            detections.append(
                {
                    "label": label_name,
                    "score": round(score, 4),
                    "box": [
                        int(min(max(x1, 0.0), width)),
                        int(min(max(y1, 0.0), height)),
                        int(min(max(x2, 0.0), width)),
                        int(min(max(y2, 0.0), height)),
                    ],
                }
            )
        return detections