Files
tokenresearch/app/detector.py
2026-06-03 11:00:50 +08:00

61 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from typing import Any
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor
class DetrVehicleDetector:
    """Detect vehicles in RGB frames with a pretrained DETR model.

    The image processor and the model are loaded once in ``__init__`` and
    reused by every call to :meth:`detect`.
    """

    def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
        """Load the DETR processor and model and move the model to a device.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the detection model.
            confidence: Score threshold handed to DETR post-processing.
            vehicle_labels: Label names to keep; detections whose label is
                not in this collection are discarded.
        """
        self.confidence = confidence
        self.vehicle_labels = vehicle_labels
        self.device = self._select_device()
        self.processor = DetrImageProcessor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)
        self.model.to(self.device)
        # Inference only: put the model in evaluation mode.
        self.model.eval()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the MPS device when it is available, otherwise the CPU."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @property
    def device_name(self) -> str:
        """Type string of the selected device (``"mps"`` or ``"cpu"``)."""
        return self.device.type

    @torch.no_grad()
    def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
        """Run DETR on one frame and return the detections of wanted labels.

        Args:
            frame_rgb: Array-like image accepted by ``PIL.Image.fromarray``
                (presumably an H x W x 3 uint8 RGB array; confirm with the
                caller).

        Returns:
            One dict per kept detection with the keys ``"label"`` (label
            name), ``"score"`` (float rounded to 4 decimals) and ``"box"``
            (``[x1, y1, x2, y2]`` as integers, truncated from the
            post-processed coordinates).
        """
        image = Image.fromarray(frame_rgb)
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        outputs = self.model(**inputs)

        # DETR post-processing needs the original image size. PIL's ``size``
        # is (width, height), so reverse it to (height, width).
        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.confidence,
        )[0]

        # Copy each result tensor to the host once, instead of issuing a
        # separate device-to-host transfer for every detection in the loop.
        scores = results["scores"].detach().cpu().tolist()
        label_ids = results["labels"].detach().cpu().tolist()
        # Same NumPy integer cast as before, applied to all boxes at once.
        boxes = results["boxes"].detach().cpu().numpy().astype(int).tolist()

        id2label = self.model.config.id2label
        detections = []
        for score, label_id, box in zip(scores, label_ids, boxes):
            label_name = id2label[label_id]
            if label_name not in self.vehicle_labels:
                continue
            detections.append(
                {
                    "label": label_name,
                    "score": round(score, 4),
                    "box": box,
                }
            )
        return detections