136 lines
5.6 KiB
Python
136 lines
5.6 KiB
Python
"""DETR 车辆检测模块:加载模型并输出车辆类别、置信度和检测框。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import DetrForObjectDetection, DetrImageProcessor
|
|
|
|
|
|
class DetrVehicleDetector:
    """Run a pretrained DETR object detector and keep only vehicle classes.

    The detector loads a ``DetrImageProcessor`` / ``DetrForObjectDetection``
    pair by name, moves the model to the selected device, and puts it in
    eval mode. Detections are plain dicts with ``label``, ``score`` and
    ``box`` keys.
    """

    def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
        """Load the processor and model and prepare them for inference.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                image processor and the detection model.
            confidence: Score threshold handed to the processor's
                post-processing step.
            vehicle_labels: Label names (as found in ``model.config.id2label``)
                that count as vehicles; every other label is dropped.
        """
        self.confidence = confidence
        self.vehicle_labels = vehicle_labels
        self.device = self._select_device()
        self.processor = DetrImageProcessor.from_pretrained(model_name)
        self.model = DetrForObjectDetection.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _select_device() -> torch.device:
        """Return the MPS device when available, otherwise the CPU device."""
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @property
    def device_name(self) -> str:
        """Device type string of the device the model runs on (e.g. ``"cpu"``)."""
        return self.device.type

    def _prepare_inputs(self, frame_rgb: Any) -> tuple[Image.Image, dict[str, Any]]:
        """Convert a frame to a PIL image and device-resident model inputs.

        Shared by ``detect`` and ``inspect_tokens`` so both feed the model
        identically.
        """
        # NOTE(review): the parameter name suggests an RGB array — confirm the
        # caller converts from BGR (e.g. OpenCV frames) before calling.
        image = Image.fromarray(frame_rgb)
        encoded = self.processor(images=image, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in encoded.items()}
        return image, inputs

    @torch.no_grad()
    def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
        """Detect vehicles in one frame.

        Args:
            frame_rgb: Array accepted by ``PIL.Image.fromarray``.

        Returns:
            A list of ``{"label", "score", "box"}`` dicts, one per detection
            whose label is in ``vehicle_labels``.
        """
        image, inputs = self._prepare_inputs(frame_rgb)
        outputs = self.model(**inputs)
        return self._detections_from_outputs(image, outputs)

    @torch.no_grad()
    def inspect_tokens(
        self,
        frame_rgb: Any,
        *,
        max_tokens: int = 48,
        max_channels: int = 8,
    ) -> dict[str, Any]:
        """Collect intermediate tensor shapes and a token sample for one frame.

        The backbone and input projection are run explicitly to expose the
        feature map and the flattened token sequence, then the full model is
        run to obtain detections and output shapes.

        Args:
            frame_rgb: Array accepted by ``PIL.Image.fromarray``.
            max_tokens: Maximum number of leading tokens included in
                ``token_sequence`` (default 48).
            max_channels: Number of leading channels kept per sampled token
                (default 8).

        Returns:
            A dict with the image size, the shapes of the intermediate and
            output tensors, the token grid description, the sampled tokens,
            and the vehicle detections.
        """
        image, inputs = self._prepare_inputs(frame_rgb)

        # The backbone returns per-level (feature_map, mask) pairs plus a
        # parallel list reported below as the position encoding.
        # NOTE(review): these internals (`model.model.backbone`,
        # `input_projection`) depend on the installed transformers version.
        features, object_queries_list = self.model.model.backbone(
            inputs["pixel_values"], inputs["pixel_mask"]
        )
        feature_map, _ = features[-1]
        projected_feature_map = self.model.model.input_projection(feature_map)

        # Merge dims 2.. into one axis and move it before the channel axis,
        # giving one row per spatial position.
        tokens = projected_feature_map.flatten(2).permute(0, 2, 1)
        object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

        outputs = self.model(**inputs, output_hidden_states=True)
        detections = self._detections_from_outputs(image, outputs)

        token_rows = int(projected_feature_map.shape[2])
        token_cols = int(projected_feature_map.shape[3])
        total_tokens = int(tokens.shape[1])
        sample_count = max(0, min(max_tokens, total_tokens))
        channel_count = max(0, max_channels)
        sample_tokens = tokens[0, :sample_count, :channel_count].detach().cpu()

        # "magnitude" is the norm of the truncated vector shown in "values",
        # not of the full token embedding.
        token_sequence = [
            {
                "index": index,
                "row": index // token_cols,
                "col": index % token_cols,
                "values": [round(float(value), 4) for value in vector.tolist()],
                "magnitude": round(float(vector.norm()), 4),
            }
            for index, vector in enumerate(sample_tokens)
        ]

        encoder_last_hidden_state = getattr(outputs, "encoder_last_hidden_state", None)
        last_hidden_state = getattr(outputs, "last_hidden_state", None)
        return {
            "image_size": {"width": image.size[0], "height": image.size[1]},
            "pixel_values_shape": list(inputs["pixel_values"].shape),
            "pixel_mask_shape": list(inputs["pixel_mask"].shape),
            "feature_map_shape": list(feature_map.shape),
            "projected_feature_map_shape": list(projected_feature_map.shape),
            "visual_tokens_shape": list(tokens.shape),
            "position_encoding_shape": list(object_queries.shape),
            "encoder_last_hidden_state_shape": (
                list(encoder_last_hidden_state.shape)
                if encoder_last_hidden_state is not None
                else []
            ),
            "decoder_last_hidden_state_shape": (
                list(last_hidden_state.shape) if last_hidden_state is not None else []
            ),
            "logits_shape": list(outputs.logits.shape),
            "pred_boxes_shape": list(outputs.pred_boxes.shape),
            "token_grid": {
                "rows": token_rows,
                "cols": token_cols,
                "total": total_tokens,
                "shown": sample_count,
            },
            "token_sequence": token_sequence,
            "detections": detections,
        }

    def _detections_from_outputs(self, image: Image.Image, outputs: Any) -> list[dict[str, Any]]:
        """Post-process raw model outputs into filtered vehicle detections.

        Args:
            image: The image the outputs were computed from; its ``size``
                (width, height) is reversed to (height, width) for rescaling.
            outputs: Raw model outputs passed to the processor's
                ``post_process_object_detection``.

        Returns:
            A list of dicts with ``label`` (str), ``score`` (float rounded to
            4 decimals) and ``box`` (``[x1, y1, x2, y2]`` truncated to ints),
            restricted to labels in ``vehicle_labels``.
        """
        target_sizes = torch.tensor([image.size[::-1]], device=self.device)
        results = self.processor.post_process_object_detection(
            outputs,
            target_sizes=target_sizes,
            threshold=self.confidence,
        )[0]

        # Move each result tensor to the host once instead of synchronising
        # with the device for every individual detection.
        scores = results["scores"].detach().cpu().tolist()
        labels = results["labels"].detach().cpu().tolist()
        boxes = results["boxes"].detach().cpu().numpy().astype(int).tolist()

        id2label = self.model.config.id2label
        detections = []
        for score, label, box in zip(scores, labels, boxes):
            label_name = id2label[label]
            if label_name not in self.vehicle_labels:
                continue
            x1, y1, x2, y2 = box
            detections.append(
                {
                    "label": label_name,
                    "score": round(float(score), 4),
                    "box": [x1, y1, x2, y2],
                }
            )
        return detections
|
|
|