"""DETR 车辆检测模块:加载模型并输出车辆类别、置信度和检测框。""" from __future__ import annotations from typing import Any import torch from PIL import Image from transformers import DetrForObjectDetection, DetrImageProcessor class DetrVehicleDetector: def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]): self.confidence = confidence self.vehicle_labels = vehicle_labels self.device = self._select_device() self.processor = DetrImageProcessor.from_pretrained(model_name) self.model = DetrForObjectDetection.from_pretrained(model_name) self.model.to(self.device) self.model.eval() @staticmethod def _select_device() -> torch.device: if torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") @property def device_name(self) -> str: return self.device.type @torch.no_grad() def detect(self, frame_rgb: Any) -> list[dict[str, Any]]: image = Image.fromarray(frame_rgb) inputs = self.processor(images=image, return_tensors="pt") inputs = {key: value.to(self.device) for key, value in inputs.items()} outputs = self.model(**inputs) return self._detections_from_outputs(image, outputs) @torch.no_grad() def inspect_tokens(self, frame_rgb: Any) -> dict[str, Any]: image = Image.fromarray(frame_rgb) inputs = self.processor(images=image, return_tensors="pt") inputs = {key: value.to(self.device) for key, value in inputs.items()} features, object_queries_list = self.model.model.backbone(inputs["pixel_values"], inputs["pixel_mask"]) feature_map, mask = features[-1] projected_feature_map = self.model.model.input_projection(feature_map) tokens = projected_feature_map.flatten(2).permute(0, 2, 1) object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) outputs = self.model(**inputs, output_hidden_states=True) target_sizes = torch.tensor([image.size[::-1]], device=self.device) results = self.processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=self.confidence, )[0] token_rows = int(projected_feature_map.shape[2]) token_cols = int(projected_feature_map.shape[3]) sample_count = min(48, int(tokens.shape[1])) sample_tokens = tokens[0, :sample_count, :8].detach().cpu() token_sequence = [] for index, vector in enumerate(sample_tokens): token_sequence.append( { "index": index, "row": index // token_cols, "col": index % token_cols, "values": [round(float(value), 4) for value in vector.tolist()], "magnitude": round(float(vector.norm()), 4), } ) detections = [] for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): label_name = self.model.config.id2label[label.item()] if label_name not in self.vehicle_labels: continue x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist() detections.append( { "label": label_name, "score": round(float(score.detach().cpu()), 4), "box": [x1, y1, x2, y2], } ) encoder_last_hidden_state = getattr(outputs, "encoder_last_hidden_state", None) last_hidden_state = getattr(outputs, "last_hidden_state", None) return { "image_size": {"width": image.size[0], "height": image.size[1]}, "pixel_values_shape": list(inputs["pixel_values"].shape), "pixel_mask_shape": list(inputs["pixel_mask"].shape), "feature_map_shape": list(feature_map.shape), "projected_feature_map_shape": list(projected_feature_map.shape), "visual_tokens_shape": list(tokens.shape), "position_encoding_shape": list(object_queries.shape), "encoder_last_hidden_state_shape": list(encoder_last_hidden_state.shape) if encoder_last_hidden_state is not None else [], "decoder_last_hidden_state_shape": list(last_hidden_state.shape) if last_hidden_state is not None else [], "logits_shape": list(outputs.logits.shape), "pred_boxes_shape": list(outputs.pred_boxes.shape), "token_grid": {"rows": token_rows, "cols": token_cols, "total": int(tokens.shape[1]), "shown": sample_count}, "token_sequence": token_sequence, "detections": detections, } def _detections_from_outputs(self, image: Image.Image, outputs: Any) -> list[dict[str, Any]]: target_sizes = torch.tensor([image.size[::-1]], device=self.device) results = self.processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=self.confidence, )[0] detections = [] for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): label_name = self.model.config.id2label[label.item()] if label_name not in self.vehicle_labels: continue x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist() detections.append( { "label": label_name, "score": round(float(score.detach().cpu()), 4), "box": [x1, y1, x2, y2], } ) return detections