first commit
This commit is contained in:
13
.env.example
Normal file
13
.env.example
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
API_BASE_URL=https://apicapacity.51iwifi.com
|
||||||
|
APP_ID=your_app_id
|
||||||
|
APP_SECRET=your_app_secret
|
||||||
|
DEVICE_LIST_PATH=devicelist.env
|
||||||
|
DEVICE_ACCOUNT=21cn
|
||||||
|
STREAM_METHOD=capacity.geye.device.devUrl.get
|
||||||
|
STREAM_URL=
|
||||||
|
DETR_CONFIDENCE=0.6
|
||||||
|
FRAME_SKIP=5
|
||||||
|
DETR_MODEL=facebook/detr-resnet-50
|
||||||
|
JPEG_QUALITY=80
|
||||||
|
RESIZE_WIDTH=960
|
||||||
|
VEHICLE_LABELS=car,motorcycle,bus,truck,bicycle
|
||||||
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
*.egg-info/
|
||||||
|
.env
|
||||||
|
devicelist.env
|
||||||
|
.DS_Store
|
||||||
137
README.md
Normal file
137
README.md
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
# macOS DETR 车辆动态打标
|
||||||
|
|
||||||
|
这是一个参考 `../VideoPipe` 思路实现的 macOS 友好版 Python 项目:
|
||||||
|
|
||||||
|
```text
|
||||||
|
RTSP/HLS/本地文件源节点 -> DETR 车辆检测推理节点 -> OSD 画框节点 -> FastAPI 远程输出节点
|
||||||
|
```
|
||||||
|
|
||||||
|
本项目不直接移植 VideoPipe 的 C++、TensorRT、CUDA、GStreamer RTSP Server 工程,而是保留它的“源节点 → 推理节点 → OSD → 输出节点”思想,使用 Python、OpenCV、PyTorch、Transformers DETR 和 FastAPI,方便在 Mac mini / Apple Silicon 上运行。
|
||||||
|
|
||||||
|
## 架构和推理流程
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart LR
|
||||||
|
A[devicelist.env 摄像头列表] --> B[DeviceManager 默认选择第一个设备]
|
||||||
|
B --> C[能力开放接口获取 RTSP 地址]
|
||||||
|
C --> D[OpenCV VideoCapture 读取视频流]
|
||||||
|
D --> E[StreamWorker 抽帧]
|
||||||
|
E --> F[DetrVehicleDetector 车辆检测]
|
||||||
|
F --> G[OpenCV OSD 画框]
|
||||||
|
G --> H[FastAPI MJPEG /video]
|
||||||
|
G --> I[WebSocket /ws/detections]
|
||||||
|
H --> J[浏览器动态打标画面]
|
||||||
|
I --> J
|
||||||
|
```
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A[读取一帧 BGR 图像] --> B{是否达到 FRAME_SKIP}
|
||||||
|
B -- 否 --> F[复用上一轮检测结果]
|
||||||
|
B -- 是 --> C[转换为 RGB]
|
||||||
|
C --> D[Transformers ImageProcessor 预处理]
|
||||||
|
D --> E[PyTorch DETR 推理]
|
||||||
|
E --> G[按 DETR_CONFIDENCE 过滤车辆类别]
|
||||||
|
F --> H[绘制检测框和标签]
|
||||||
|
G --> H
|
||||||
|
H --> I[JPEG 编码]
|
||||||
|
I --> J[FastAPI 输出到远程浏览器]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
视频流地址通过环境变量传入,不要把 RTSP/HLS token 写进代码或文档。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export STREAM_URL='rtsp://你的摄像头或视频流地址'
|
||||||
|
export DETR_CONFIDENCE=0.6
|
||||||
|
export FRAME_SKIP=5
|
||||||
|
```
|
||||||
|
|
||||||
|
可选配置:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export DETR_MODEL='facebook/detr-resnet-50'
|
||||||
|
export JPEG_QUALITY=80
|
||||||
|
export RESIZE_WIDTH=960
|
||||||
|
export VEHICLE_LABELS='car,motorcycle,bus,truck,bicycle'
|
||||||
|
```
|
||||||
|
|
||||||
|
配置说明:
|
||||||
|
|
||||||
|
- `STREAM_URL`:RTSP、HLS 或本地视频文件路径,必填。
|
||||||
|
- `DETR_CONFIDENCE`:检测置信度阈值,默认 `0.6`。
|
||||||
|
- `FRAME_SKIP`:每隔多少帧做一次 DETR 推理,默认 `3`。数值越大,负载越低。
|
||||||
|
- `RESIZE_WIDTH`:可选,设置后会按宽度等比缩小视频帧,降低 Mac mini 压力。
|
||||||
|
- `VEHICLE_LABELS`:需要保留的 COCO 车辆类别。
|
||||||
|
|
||||||
|
## 运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
本机访问:
|
||||||
|
|
||||||
|
```text
|
||||||
|
http://127.0.0.1:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
局域网远端访问:
|
||||||
|
|
||||||
|
```text
|
||||||
|
http://<Mac-IP>:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## 接口
|
||||||
|
|
||||||
|
- `GET /`:中文浏览器看板。
|
||||||
|
- `GET /video`:MJPEG 动态打标视频流。
|
||||||
|
- `GET /detections`:最近一帧检测结果 JSON。
|
||||||
|
- `GET /status`:运行状态、模型、设备、帧号、FPS。
|
||||||
|
- `WS /ws/detections`:实时推送检测元数据。
|
||||||
|
|
||||||
|
## RTSP 和 HLS 选择
|
||||||
|
|
||||||
|
建议优先使用 RTSP,实时性更好。HLS `.m3u8` 通常更稳定,但会有几秒到十几秒延迟。
|
||||||
|
|
||||||
|
OpenCV 需要带 FFmpeg 支持才能打开很多 RTSP/HLS 地址。如果无法打开远端流,可以先用本地视频文件验证项目,再考虑增加 PyAV 或 ffmpeg-python 输入后端。
|
||||||
|
|
||||||
|
## macOS 远端访问检查
|
||||||
|
|
||||||
|
如果其他机器访问不到页面,检查 macOS 防火墙,并确认服务监听在 `0.0.0.0`。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8000/status
|
||||||
|
curl http://<Mac-IP>:8000/status
|
||||||
|
lsof -i :8000
|
||||||
|
```
|
||||||
|
|
||||||
|
## 性能建议
|
||||||
|
|
||||||
|
DETR 原版模型较重。Mac mini 上建议先这样跑:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export FRAME_SKIP=5
|
||||||
|
export RESIZE_WIDTH=960
|
||||||
|
```
|
||||||
|
|
||||||
|
Apple Silicon 上会优先使用 PyTorch MPS;不可用时自动回退到 CPU。
|
||||||
|
|
||||||
|
## 与 VideoPipe 的对应关系
|
||||||
|
|
||||||
|
| VideoPipe 概念 | 本项目实现 |
|
||||||
|
| --- | --- |
|
||||||
|
| 源节点 | `StreamWorker` 中的 OpenCV `VideoCapture` |
|
||||||
|
| 推理节点 | `DetrVehicleDetector` |
|
||||||
|
| OSD 节点 | `StreamWorker._draw()` |
|
||||||
|
| 输出节点 | FastAPI `/video`、`/detections`、`/ws/detections` |
|
||||||
|
| 分析看板 | 浏览器中文 Dashboard |
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
123
app/capacity_api.py
Normal file
123
app/capacity_api.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Token:
|
||||||
|
access_token: str
|
||||||
|
expires_at: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamUrlResult:
|
||||||
|
url: str
|
||||||
|
timings: dict[str, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CapacityApiClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str,
|
||||||
|
app_id: str,
|
||||||
|
app_secret: str,
|
||||||
|
account: str,
|
||||||
|
method: str,
|
||||||
|
timeout: int = 20,
|
||||||
|
):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.app_id = app_id
|
||||||
|
self.app_secret = app_secret
|
||||||
|
self.account = account
|
||||||
|
self.method = method
|
||||||
|
self.timeout = timeout
|
||||||
|
self.token: Token | None = None
|
||||||
|
|
||||||
|
def get_stream_url(self, device_num: str) -> str:
|
||||||
|
return self.get_stream_url_details(device_num).url
|
||||||
|
|
||||||
|
def get_stream_url_details(self, device_num: str) -> StreamUrlResult:
|
||||||
|
timings: dict[str, float] = {}
|
||||||
|
started = time.monotonic()
|
||||||
|
access_token = self._get_access_token(timings)
|
||||||
|
timings["token_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
started = time.monotonic()
|
||||||
|
business_params = {
|
||||||
|
"account": self.account,
|
||||||
|
"deviceNum": device_num,
|
||||||
|
"isSubStream": 0,
|
||||||
|
"networkType": 1,
|
||||||
|
"urlType": 1,
|
||||||
|
}
|
||||||
|
# 接口文档要求业务参数整体放进 params JSON 字符串后再参与签名。
|
||||||
|
params = {
|
||||||
|
"accessToken": access_token,
|
||||||
|
"appId": self.app_id,
|
||||||
|
"method": self.method,
|
||||||
|
"params": json.dumps(business_params, ensure_ascii=False, separators=(",", ":")),
|
||||||
|
"timestamp": self._timestamp(),
|
||||||
|
"v": "1.0.0",
|
||||||
|
}
|
||||||
|
params["sign"] = self._sign(params)
|
||||||
|
timings["sign_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
started = time.monotonic()
|
||||||
|
data = self._get_json(f"{self.base_url}/rest", params)
|
||||||
|
timings["stream_url_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
if data.get("errorCode") != "0":
|
||||||
|
raise RuntimeError(data.get("errorMsg") or f"播放地址接口返回错误 {data.get('errorCode')}")
|
||||||
|
|
||||||
|
payload = data.get("data") or {}
|
||||||
|
stream_url = payload.get("rtspUrl") or payload.get("rtspUri")
|
||||||
|
if not stream_url:
|
||||||
|
raise RuntimeError("播放地址接口未返回 RTSP 地址")
|
||||||
|
return StreamUrlResult(stream_url, timings)
|
||||||
|
|
||||||
|
def _get_access_token(self, timings: dict[str, float] | None = None) -> str:
|
||||||
|
if self.token and time.time() < self.token.expires_at - 300:
|
||||||
|
if timings is not None:
|
||||||
|
timings["token_cache"] = 1
|
||||||
|
return self.token.access_token
|
||||||
|
|
||||||
|
data = self._get_json(
|
||||||
|
f"{self.base_url}/oauth/token",
|
||||||
|
{
|
||||||
|
"grantType": "client_credential",
|
||||||
|
"appId": self.app_id,
|
||||||
|
"appSecret": self.app_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if data.get("errorCode") != "0":
|
||||||
|
raise RuntimeError(data.get("errorMsg") or f"获取 accessToken 失败 {data.get('errorCode')}")
|
||||||
|
|
||||||
|
payload = data.get("data") or {}
|
||||||
|
access_token = payload.get("accessToken")
|
||||||
|
if not access_token:
|
||||||
|
raise RuntimeError("token 接口未返回 accessToken")
|
||||||
|
|
||||||
|
self.token = Token(
|
||||||
|
access_token=access_token,
|
||||||
|
expires_at=time.time() + int(payload.get("expiresIn") or 604800),
|
||||||
|
)
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
def _get_json(self, url: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
query = urllib.parse.urlencode(params)
|
||||||
|
with urllib.request.urlopen(f"{url}?{query}", timeout=self.timeout) as response:
|
||||||
|
body = response.read().decode("utf-8")
|
||||||
|
return json.loads(body)
|
||||||
|
|
||||||
|
def _sign(self, params: dict[str, Any]) -> str:
|
||||||
|
# 签名规则:appSecret + 按 ASCII key 排序后的 key/value + appSecret。
|
||||||
|
raw = self.app_secret + "".join(f"{key}{params[key]}" for key in sorted(params)) + self.app_secret
|
||||||
|
return hashlib.md5(raw.encode("utf-8")).hexdigest().upper()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _timestamp() -> str:
|
||||||
|
return datetime.now(timezone(timedelta(hours=8))).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
66
app/config.py
Normal file
66
app/config.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Settings:
|
||||||
|
stream_url: str
|
||||||
|
api_base_url: str
|
||||||
|
app_id: str
|
||||||
|
app_secret: str
|
||||||
|
device_list_path: str
|
||||||
|
device_account: str
|
||||||
|
stream_method: str
|
||||||
|
detr_model: str
|
||||||
|
confidence: float
|
||||||
|
frame_skip: int
|
||||||
|
jpeg_quality: int
|
||||||
|
resize_width: int | None
|
||||||
|
vehicle_labels: set[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_int(name: str) -> int | None:
|
||||||
|
value = os.getenv(name, "").strip()
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
parsed = int(value)
|
||||||
|
return parsed if parsed > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def load_settings() -> Settings:
|
||||||
|
stream_url = os.getenv("STREAM_URL", "").strip()
|
||||||
|
|
||||||
|
vehicle_labels = {
|
||||||
|
item.strip()
|
||||||
|
for item in os.getenv("VEHICLE_LABELS", "car,motorcycle,bus,truck,bicycle").split(",")
|
||||||
|
if item.strip()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Settings(
|
||||||
|
stream_url=stream_url,
|
||||||
|
api_base_url=os.getenv("API_BASE_URL", "https://apicapacity.51iwifi.com"),
|
||||||
|
app_id=os.getenv("APP_ID", ""),
|
||||||
|
app_secret=os.getenv("APP_SECRET", ""),
|
||||||
|
device_list_path=os.getenv("DEVICE_LIST_PATH", "devicelist.env"),
|
||||||
|
device_account=os.getenv("DEVICE_ACCOUNT", "21cn"),
|
||||||
|
stream_method=os.getenv("STREAM_METHOD", "capacity.geye.device.devUrl.get"),
|
||||||
|
detr_model=os.getenv("DETR_MODEL", "facebook/detr-resnet-50"),
|
||||||
|
confidence=float(os.getenv("DETR_CONFIDENCE", "0.6")),
|
||||||
|
frame_skip=max(1, int(os.getenv("FRAME_SKIP", "3"))),
|
||||||
|
jpeg_quality=min(100, max(1, int(os.getenv("JPEG_QUALITY", "80")))),
|
||||||
|
resize_width=_optional_int("RESIZE_WIDTH"),
|
||||||
|
vehicle_labels=vehicle_labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_url(url: str) -> str:
|
||||||
|
if "token=" not in url:
|
||||||
|
return url
|
||||||
|
prefix, _ = url.split("token=", 1)
|
||||||
|
return f"{prefix}token=***"
|
||||||
60
app/detector.py
Normal file
60
app/detector.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import DetrForObjectDetection, DetrImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class DetrVehicleDetector:
|
||||||
|
def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
|
||||||
|
self.confidence = confidence
|
||||||
|
self.vehicle_labels = vehicle_labels
|
||||||
|
self.device = self._select_device()
|
||||||
|
self.processor = DetrImageProcessor.from_pretrained(model_name)
|
||||||
|
self.model = DetrForObjectDetection.from_pretrained(model_name)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _select_device() -> torch.device:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return torch.device("mps")
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device_name(self) -> str:
|
||||||
|
return self.device.type
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
|
||||||
|
image = Image.fromarray(frame_rgb)
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||||
|
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
# DETR 后处理需要原图尺寸,PIL size 是 (宽, 高),这里转成 (高, 宽)。
|
||||||
|
target_sizes = torch.tensor([image.size[::-1]], device=self.device)
|
||||||
|
results = self.processor.post_process_object_detection(
|
||||||
|
outputs,
|
||||||
|
target_sizes=target_sizes,
|
||||||
|
threshold=self.confidence,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
detections = []
|
||||||
|
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
||||||
|
label_name = self.model.config.id2label[label.item()]
|
||||||
|
if label_name not in self.vehicle_labels:
|
||||||
|
continue
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist()
|
||||||
|
detections.append(
|
||||||
|
{
|
||||||
|
"label": label_name,
|
||||||
|
"score": round(float(score.detach().cpu()), 4),
|
||||||
|
"box": [x1, y1, x2, y2],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return detections
|
||||||
91
app/device_manager.py
Normal file
91
app/device_manager.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.capacity_api import CapacityApiClient
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Device:
|
||||||
|
name: str
|
||||||
|
device_num: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceManager:
|
||||||
|
def __init__(self, path: str, api_client: CapacityApiClient, fallback_url: str = ""):
|
||||||
|
self.devices = self._load_devices(path)
|
||||||
|
self.api_client = api_client
|
||||||
|
self.fallback_url = fallback_url
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.current_device_num = self.devices[0].device_num if self.devices else ""
|
||||||
|
self.current_url = fallback_url
|
||||||
|
self.timings: dict[str, float] = {}
|
||||||
|
self.updated_at = 0.0
|
||||||
|
self.version = 0
|
||||||
|
|
||||||
|
def set_current_device(self, device_num: str) -> int:
|
||||||
|
if device_num not in {device.device_num for device in self.devices}:
|
||||||
|
raise ValueError("设备不在 devicelist.env 中")
|
||||||
|
with self.lock:
|
||||||
|
self.current_device_num = device_num
|
||||||
|
self.current_url = ""
|
||||||
|
self.timings = {}
|
||||||
|
self.updated_at = time.time()
|
||||||
|
self.version += 1
|
||||||
|
return self.version
|
||||||
|
|
||||||
|
def resolve_stream_url(self) -> str:
|
||||||
|
with self.lock:
|
||||||
|
device_num = self.current_device_num
|
||||||
|
version = self.version
|
||||||
|
if not device_num:
|
||||||
|
if self.fallback_url:
|
||||||
|
return self.fallback_url
|
||||||
|
raise RuntimeError("devicelist.env 中没有可用设备号")
|
||||||
|
|
||||||
|
result = self.api_client.get_stream_url_details(device_num)
|
||||||
|
with self.lock:
|
||||||
|
# 避免旧摄像头的慢接口响应覆盖用户刚切换的新选择。
|
||||||
|
if version != self.version or device_num != self.current_device_num:
|
||||||
|
return self.current_url
|
||||||
|
self.current_url = result.url
|
||||||
|
self.timings = dict(result.timings)
|
||||||
|
self.updated_at = time.time()
|
||||||
|
return result.url
|
||||||
|
|
||||||
|
def get_snapshot(self) -> dict[str, Any]:
|
||||||
|
with self.lock:
|
||||||
|
return {
|
||||||
|
"devices": [device.__dict__ for device in self.devices],
|
||||||
|
"current_device_num": self.current_device_num,
|
||||||
|
"current_url": self.current_url,
|
||||||
|
"source_timings": dict(self.timings),
|
||||||
|
"source_updated_at": self.updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_devices(path: str) -> list[Device]:
|
||||||
|
devices: list[Device] = []
|
||||||
|
file_path = Path(path)
|
||||||
|
if not file_path.exists():
|
||||||
|
return devices
|
||||||
|
|
||||||
|
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped or stripped.startswith("#"):
|
||||||
|
continue
|
||||||
|
if "=" in stripped:
|
||||||
|
name, value = stripped.split("=", 1)
|
||||||
|
values = [item.strip() for item in value.split(",") if item.strip()]
|
||||||
|
display_name = name.strip()
|
||||||
|
else:
|
||||||
|
values = [stripped]
|
||||||
|
display_name = "摄像头"
|
||||||
|
for device_num in values:
|
||||||
|
name = display_name if len(values) == 1 else f"摄像头 {len(devices) + 1}"
|
||||||
|
devices.append(Device(name, device_num))
|
||||||
|
return devices
|
||||||
163
app/main.py
Normal file
163
app/main.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
|
from app.capacity_api import CapacityApiClient
|
||||||
|
from app.config import load_settings, mask_url
|
||||||
|
from app.detector import DetrVehicleDetector
|
||||||
|
from app.device_manager import DeviceManager
|
||||||
|
from app.stream_worker import StreamWorker
|
||||||
|
|
||||||
|
settings = load_settings()
|
||||||
|
api_client = CapacityApiClient(
|
||||||
|
base_url=settings.api_base_url,
|
||||||
|
app_id=settings.app_id,
|
||||||
|
app_secret=settings.app_secret,
|
||||||
|
account=settings.device_account,
|
||||||
|
method=settings.stream_method,
|
||||||
|
)
|
||||||
|
detector = DetrVehicleDetector(
|
||||||
|
model_name=settings.detr_model,
|
||||||
|
confidence=settings.confidence,
|
||||||
|
vehicle_labels=settings.vehicle_labels,
|
||||||
|
)
|
||||||
|
device_manager = DeviceManager(
|
||||||
|
path=settings.device_list_path,
|
||||||
|
api_client=api_client,
|
||||||
|
fallback_url=settings.stream_url,
|
||||||
|
)
|
||||||
|
def resolve_stream_url() -> str:
|
||||||
|
return device_manager.resolve_stream_url()
|
||||||
|
|
||||||
|
|
||||||
|
worker = StreamWorker(
|
||||||
|
stream_url=resolve_stream_url,
|
||||||
|
detector=detector,
|
||||||
|
frame_skip=settings.frame_skip,
|
||||||
|
jpeg_quality=settings.jpeg_quality,
|
||||||
|
resize_width=settings.resize_width,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI(title="DETR 动态打标")
|
||||||
|
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||||
|
templates = Jinja2Templates(directory="app/templates")
|
||||||
|
|
||||||
|
|
||||||
|
def display_model_name(model_name: str) -> str:
|
||||||
|
return model_name.rsplit("/", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def startup() -> None:
|
||||||
|
worker.start()
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def shutdown() -> None:
|
||||||
|
worker.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/", response_class=HTMLResponse)
|
||||||
|
def index(request: Request) -> HTMLResponse:
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
"index.html",
|
||||||
|
{
|
||||||
|
"request": request,
|
||||||
|
"model": display_model_name(settings.detr_model),
|
||||||
|
"device": detector.device_name,
|
||||||
|
"stream_url": f"设备号:{device_manager.get_snapshot()['current_device_num']}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/video")
|
||||||
|
def video() -> StreamingResponse:
|
||||||
|
async def generate():
|
||||||
|
while True:
|
||||||
|
frame = worker.get_jpeg()
|
||||||
|
if frame is None:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n"
|
||||||
|
await asyncio.sleep(0.03)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate(),
|
||||||
|
media_type="multipart/x-mixed-replace; boundary=frame",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/detections")
|
||||||
|
def detections() -> JSONResponse:
|
||||||
|
snapshot = worker.get_snapshot()
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"frame_id": snapshot["frame_id"],
|
||||||
|
"updated_at": snapshot["updated_at"],
|
||||||
|
"detections": snapshot["detections"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/status")
|
||||||
|
def status() -> JSONResponse:
|
||||||
|
snapshot = worker.get_snapshot()
|
||||||
|
device_snapshot = device_manager.get_snapshot()
|
||||||
|
timings = dict(device_snapshot["source_timings"])
|
||||||
|
# 合并取流地址和 OpenCV 读流耗时,前端按同一个 timings 对象展示。
|
||||||
|
timings.update(snapshot["timings"])
|
||||||
|
return JSONResponse(
|
||||||
|
{
|
||||||
|
"running": snapshot["running"],
|
||||||
|
"connected": snapshot["connected"],
|
||||||
|
"frame_id": snapshot["frame_id"],
|
||||||
|
"updated_at": snapshot["updated_at"],
|
||||||
|
"fps": snapshot["fps"],
|
||||||
|
"error": snapshot["error"],
|
||||||
|
"source": mask_url(device_snapshot["current_url"]) if device_snapshot["current_url"] else "等待获取播放地址",
|
||||||
|
"model": display_model_name(settings.detr_model),
|
||||||
|
"device": detector.device_name,
|
||||||
|
"frame_skip": settings.frame_skip,
|
||||||
|
"confidence": settings.confidence,
|
||||||
|
"devices": device_snapshot["devices"],
|
||||||
|
"current_device_num": device_snapshot["current_device_num"],
|
||||||
|
"timings": timings,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/devices/{device_num}")
|
||||||
|
def switch_device(device_num: str) -> JSONResponse:
|
||||||
|
try:
|
||||||
|
device_manager.set_current_device(device_num)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||||
|
worker.reconnect()
|
||||||
|
return JSONResponse({"current_device_num": device_num})
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws/detections")
|
||||||
|
async def websocket_detections(websocket: WebSocket) -> None:
|
||||||
|
await websocket.accept()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
data = worker.get_snapshot()
|
||||||
|
device_snapshot = device_manager.get_snapshot()
|
||||||
|
data.update(
|
||||||
|
{
|
||||||
|
"devices": device_snapshot["devices"],
|
||||||
|
"current_device_num": device_snapshot["current_device_num"],
|
||||||
|
"source": mask_url(device_snapshot["current_url"]) if device_snapshot["current_url"] else "等待获取播放地址",
|
||||||
|
"timings": {**device_snapshot["source_timings"], **data["timings"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await websocket.send_json(data)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
return
|
||||||
142
app/static/app.js
Normal file
142
app/static/app.js
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
const connection = document.querySelector("#connection");
|
||||||
|
const detectionsEl = document.querySelector("#detections");
|
||||||
|
const frameIdEl = document.querySelector("#frame-id");
|
||||||
|
const fpsEl = document.querySelector("#fps");
|
||||||
|
const errorEl = document.querySelector("#error");
|
||||||
|
const sourceEl = document.querySelector("#source");
|
||||||
|
const deviceSelect = document.querySelector("#device-select");
|
||||||
|
const timingTokenEl = document.querySelector("#timing-token");
|
||||||
|
const timingSignEl = document.querySelector("#timing-sign");
|
||||||
|
const timingUrlEl = document.querySelector("#timing-url");
|
||||||
|
const timingOpenEl = document.querySelector("#timing-open");
|
||||||
|
const timingFrameEl = document.querySelector("#timing-frame");
|
||||||
|
|
||||||
|
let selectedDevice = "";
|
||||||
|
let pendingDevice = "";
|
||||||
|
let devicesSignature = "";
|
||||||
|
|
||||||
|
function setConnection(online, text) {
|
||||||
|
connection.textContent = text;
|
||||||
|
connection.classList.toggle("online", online);
|
||||||
|
connection.classList.toggle("offline", !online);
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatMs(value) {
|
||||||
|
if (value === undefined || value === null || value === 0) {
|
||||||
|
return "-";
|
||||||
|
}
|
||||||
|
return `${Number(value).toFixed(2)} ms`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderDevices(devices, currentDeviceNum) {
|
||||||
|
if (!devices.length) {
|
||||||
|
deviceSelect.innerHTML = '<option value="">未配置摄像头</option>';
|
||||||
|
deviceSelect.disabled = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const displayDevice = pendingDevice || currentDeviceNum;
|
||||||
|
const nextSignature = `${displayDevice}|${devices.map((device) => device.device_num).join(",")}`;
|
||||||
|
if (nextSignature === devicesSignature) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedDevice = displayDevice;
|
||||||
|
devicesSignature = nextSignature;
|
||||||
|
deviceSelect.innerHTML = devices
|
||||||
|
.map((device) => {
|
||||||
|
const selected = device.device_num === displayDevice ? "selected" : "";
|
||||||
|
return `<option value="${device.device_num}" ${selected}>${device.name} · ${device.device_num}</option>`;
|
||||||
|
})
|
||||||
|
.join("");
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderTimings(timings) {
|
||||||
|
timingTokenEl.textContent = timings?.token_cache ? "缓存" : formatMs(timings?.token_ms);
|
||||||
|
timingSignEl.textContent = formatMs(timings?.sign_ms);
|
||||||
|
timingUrlEl.textContent = formatMs(timings?.stream_url_ms);
|
||||||
|
timingOpenEl.textContent = formatMs(timings?.open_ms);
|
||||||
|
timingFrameEl.textContent = formatMs(timings?.first_frame_ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderDetections(detections) {
|
||||||
|
if (!detections.length) {
|
||||||
|
detectionsEl.className = "detections empty";
|
||||||
|
detectionsEl.textContent = "暂无目标";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
detectionsEl.className = "detections";
|
||||||
|
detectionsEl.innerHTML = detections
|
||||||
|
.map((det) => {
|
||||||
|
const score = `${(det.score * 100).toFixed(1)}%`;
|
||||||
|
const box = det.box.join(", ");
|
||||||
|
return `
|
||||||
|
<div class="det-item">
|
||||||
|
<div class="det-title">
|
||||||
|
<span>${det.label}</span>
|
||||||
|
<span>${score}</span>
|
||||||
|
</div>
|
||||||
|
<div class="det-box">box: [${box}]</div>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
})
|
||||||
|
.join("");
|
||||||
|
}
|
||||||
|
|
||||||
|
async function switchDevice(deviceNum) {
|
||||||
|
pendingDevice = deviceNum;
|
||||||
|
devicesSignature = "";
|
||||||
|
setConnection(false, "切换中");
|
||||||
|
const response = await fetch(`/devices/${encodeURIComponent(deviceNum)}`, { method: "POST" });
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error("切换摄像头失败");
|
||||||
|
}
|
||||||
|
const video = document.querySelector("#video");
|
||||||
|
video.src = `/video?t=${Date.now()}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function connectWebSocket() {
|
||||||
|
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||||
|
const ws = new WebSocket(`${protocol}://${window.location.host}/ws/detections`);
|
||||||
|
|
||||||
|
ws.addEventListener("open", () => setConnection(true, "已连接"));
|
||||||
|
|
||||||
|
ws.addEventListener("message", (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
frameIdEl.textContent = data.frame_id ?? "-";
|
||||||
|
fpsEl.textContent = data.fps ?? "-";
|
||||||
|
errorEl.textContent = data.error || (data.connected ? "正常" : "未连接");
|
||||||
|
sourceEl.textContent = data.source || "-";
|
||||||
|
setConnection(Boolean(data.connected), data.connected ? "已连接" : "重连中");
|
||||||
|
if (pendingDevice && data.current_device_num === pendingDevice) {
|
||||||
|
pendingDevice = "";
|
||||||
|
deviceSelect.disabled = false;
|
||||||
|
}
|
||||||
|
renderDevices(data.devices || [], data.current_device_num || "");
|
||||||
|
renderTimings(data.timings || {});
|
||||||
|
renderDetections(data.detections || []);
|
||||||
|
});
|
||||||
|
|
||||||
|
ws.addEventListener("close", () => {
|
||||||
|
setConnection(false, "已断开");
|
||||||
|
setTimeout(connectWebSocket, 1500);
|
||||||
|
});
|
||||||
|
|
||||||
|
ws.addEventListener("error", () => {
|
||||||
|
setConnection(false, "连接错误");
|
||||||
|
ws.close();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceSelect.addEventListener("change", (event) => {
|
||||||
|
selectedDevice = event.target.value;
|
||||||
|
switchDevice(event.target.value).catch(() => {
|
||||||
|
pendingDevice = "";
|
||||||
|
devicesSignature = "";
|
||||||
|
setConnection(false, "切换失败");
|
||||||
|
deviceSelect.disabled = false;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
connectWebSocket();
|
||||||
209
app/static/style.css
Normal file
209
app/static/style.css
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
:root {
|
||||||
|
color-scheme: dark;
|
||||||
|
--bg: #0c1017;
|
||||||
|
--panel: #151b26;
|
||||||
|
--panel-2: #101722;
|
||||||
|
--text: #eef4ff;
|
||||||
|
--muted: #8f9db3;
|
||||||
|
--line: #273246;
|
||||||
|
--green: #2ee887;
|
||||||
|
--yellow: #f7c948;
|
||||||
|
--red: #ff6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
min-height: 100vh;
|
||||||
|
background: radial-gradient(circle at top left, #182235, var(--bg) 45%);
|
||||||
|
color: var(--text);
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||||
|
}
|
||||||
|
|
||||||
|
.topbar {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 24px;
|
||||||
|
padding: 22px 28px;
|
||||||
|
border-bottom: 1px solid var(--line);
|
||||||
|
background: rgba(12, 16, 23, 0.82);
|
||||||
|
backdrop-filter: blur(14px);
|
||||||
|
}
|
||||||
|
|
||||||
|
h1,
|
||||||
|
h2,
|
||||||
|
p {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 24px;
|
||||||
|
letter-spacing: 0.02em;
|
||||||
|
}
|
||||||
|
|
||||||
|
h2 {
|
||||||
|
margin-bottom: 14px;
|
||||||
|
font-size: 18px;
|
||||||
|
}
|
||||||
|
|
||||||
|
p {
|
||||||
|
margin-top: 8px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge {
|
||||||
|
min-width: 86px;
|
||||||
|
padding: 8px 12px;
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 999px;
|
||||||
|
color: var(--yellow);
|
||||||
|
text-align: center;
|
||||||
|
background: var(--panel);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge.online {
|
||||||
|
color: var(--green);
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge.offline {
|
||||||
|
color: var(--red);
|
||||||
|
}
|
||||||
|
|
||||||
|
.layout {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: minmax(0, 1fr) 360px;
|
||||||
|
gap: 18px;
|
||||||
|
padding: 18px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.video-card,
|
||||||
|
.side-card {
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 18px;
|
||||||
|
background: rgba(21, 27, 38, 0.9);
|
||||||
|
box-shadow: 0 18px 40px rgba(0, 0, 0, 0.28);
|
||||||
|
}
|
||||||
|
|
||||||
|
.video-card {
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.pipeline {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 14px;
|
||||||
|
border-bottom: 1px solid var(--line);
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stage {
|
||||||
|
flex: 0 0 auto;
|
||||||
|
padding: 9px 12px;
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 10px;
|
||||||
|
color: var(--muted);
|
||||||
|
background: var(--panel-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.stage.active {
|
||||||
|
border-color: rgba(46, 232, 135, 0.5);
|
||||||
|
color: var(--green);
|
||||||
|
}
|
||||||
|
|
||||||
|
.arrow {
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.video-wrap {
|
||||||
|
display: grid;
|
||||||
|
place-items: center;
|
||||||
|
min-height: 420px;
|
||||||
|
background: #05070b;
|
||||||
|
}
|
||||||
|
|
||||||
|
#video {
|
||||||
|
display: block;
|
||||||
|
width: 100%;
|
||||||
|
height: auto;
|
||||||
|
max-height: calc(100vh - 190px);
|
||||||
|
object-fit: contain;
|
||||||
|
}
|
||||||
|
|
||||||
|
.side-card {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 22px;
|
||||||
|
padding: 18px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: 72px minmax(0, 1fr);
|
||||||
|
gap: 10px 12px;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-grid dt {
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-grid dd {
|
||||||
|
margin: 0;
|
||||||
|
word-break: break-all;
|
||||||
|
}
|
||||||
|
|
||||||
|
.device-select {
|
||||||
|
width: 100%;
|
||||||
|
min-height: 34px;
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 8px;
|
||||||
|
color: var(--text);
|
||||||
|
background: var(--panel-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.detections {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.detections.empty {
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.det-item {
|
||||||
|
padding: 12px;
|
||||||
|
border: 1px solid var(--line);
|
||||||
|
border-radius: 12px;
|
||||||
|
background: var(--panel-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.det-title {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
color: var(--green);
|
||||||
|
font-weight: 700;
|
||||||
|
}
|
||||||
|
|
||||||
|
.det-box {
|
||||||
|
color: var(--muted);
|
||||||
|
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||||
|
font-size: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 980px) {
|
||||||
|
.layout {
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
}
|
||||||
|
|
||||||
|
.topbar {
|
||||||
|
align-items: flex-start;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
}
|
||||||
221
app/stream_worker.py
Normal file
221
app/stream_worker.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class StreamWorker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
stream_url: str | Callable[[], str],
|
||||||
|
detector: Any,
|
||||||
|
frame_skip: int = 3,
|
||||||
|
jpeg_quality: int = 80,
|
||||||
|
resize_width: int | None = None,
|
||||||
|
):
|
||||||
|
self.stream_url = stream_url
|
||||||
|
self.detector = detector
|
||||||
|
self.frame_skip = max(1, frame_skip)
|
||||||
|
self.jpeg_quality = jpeg_quality
|
||||||
|
self.resize_width = resize_width
|
||||||
|
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.latest_jpeg: bytes | None = None
|
||||||
|
self.latest_detections: list[dict[str, Any]] = []
|
||||||
|
self.frame_id = 0
|
||||||
|
self.updated_at = 0.0
|
||||||
|
self.running = False
|
||||||
|
self.connected = False
|
||||||
|
self.error = "尚未启动"
|
||||||
|
self.fps = 0.0
|
||||||
|
self.thread: threading.Thread | None = None
|
||||||
|
self.reconnect_requested = False
|
||||||
|
self.reconnect_version = 0
|
||||||
|
self.resolve_ms = 0.0
|
||||||
|
self.open_ms = 0.0
|
||||||
|
self.first_frame_ms = 0.0
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
if self.running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.running = False
|
||||||
|
if self.thread and self.thread.is_alive():
|
||||||
|
self.thread.join(timeout=2)
|
||||||
|
|
||||||
|
def reconnect(self) -> None:
|
||||||
|
with self.lock:
|
||||||
|
self.latest_jpeg = None
|
||||||
|
self.latest_detections = []
|
||||||
|
self.frame_id = 0
|
||||||
|
self.fps = 0.0
|
||||||
|
self.reconnect_requested = True
|
||||||
|
self.reconnect_version += 1
|
||||||
|
self.connected = False
|
||||||
|
self.error = "正在切换视频源"
|
||||||
|
self.resolve_ms = 0.0
|
||||||
|
self.open_ms = 0.0
|
||||||
|
self.first_frame_ms = 0.0
|
||||||
|
|
||||||
|
def get_jpeg(self) -> bytes | None:
|
||||||
|
with self.lock:
|
||||||
|
return self.latest_jpeg
|
||||||
|
|
||||||
|
def get_snapshot(self) -> dict[str, Any]:
|
||||||
|
with self.lock:
|
||||||
|
return {
|
||||||
|
"frame_id": self.frame_id,
|
||||||
|
"updated_at": self.updated_at,
|
||||||
|
"detections": list(self.latest_detections),
|
||||||
|
"running": self.running,
|
||||||
|
"connected": self.connected,
|
||||||
|
"error": self.error,
|
||||||
|
"fps": round(self.fps, 2),
|
||||||
|
"timings": {
|
||||||
|
"resolve_ms": self.resolve_ms,
|
||||||
|
"open_ms": self.open_ms,
|
||||||
|
"first_frame_ms": self.first_frame_ms,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
cap: cv2.VideoCapture | None = None
|
||||||
|
last_detections: list[dict[str, Any]] = []
|
||||||
|
fps_window_start = time.monotonic()
|
||||||
|
fps_frames = 0
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
with self.lock:
|
||||||
|
should_reconnect = self.reconnect_requested
|
||||||
|
run_version = self.reconnect_version
|
||||||
|
self.reconnect_requested = False
|
||||||
|
if should_reconnect:
|
||||||
|
# 切换摄像头时必须释放旧连接,否则 OpenCV 会继续阻塞读旧流。
|
||||||
|
if cap is not None:
|
||||||
|
cap.release()
|
||||||
|
cap = None
|
||||||
|
|
||||||
|
if cap is None or not cap.isOpened():
|
||||||
|
started = time.monotonic()
|
||||||
|
stream_url = self.stream_url() if callable(self.stream_url) else self.stream_url
|
||||||
|
resolve_ms = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
started = time.monotonic()
|
||||||
|
cap = cv2.VideoCapture(stream_url)
|
||||||
|
open_ms = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
with self.lock:
|
||||||
|
self.open_ms = open_ms
|
||||||
|
self.resolve_ms = resolve_ms
|
||||||
|
self.first_frame_ms = 0.0
|
||||||
|
if not cap.isOpened():
|
||||||
|
self._set_connection_state(False, "无法打开视频流,2 秒后重试")
|
||||||
|
cap.release()
|
||||||
|
cap = None
|
||||||
|
time.sleep(2)
|
||||||
|
continue
|
||||||
|
self._set_connection_state(True, "已连接")
|
||||||
|
|
||||||
|
started = time.monotonic()
|
||||||
|
ok, frame = cap.read()
|
||||||
|
with self.lock:
|
||||||
|
current_version = self.reconnect_version
|
||||||
|
# 丢弃切换期间从旧连接读到的帧,避免前端画面回跳。
|
||||||
|
if current_version != run_version:
|
||||||
|
continue
|
||||||
|
if self.first_frame_ms == 0.0:
|
||||||
|
with self.lock:
|
||||||
|
self.first_frame_ms = round((time.monotonic() - started) * 1000, 2)
|
||||||
|
if not ok:
|
||||||
|
self._set_connection_state(False, "读取视频帧失败,正在重连")
|
||||||
|
cap.release()
|
||||||
|
cap = None
|
||||||
|
time.sleep(2)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame = self._resize(frame)
|
||||||
|
self.frame_id += 1
|
||||||
|
|
||||||
|
if self.frame_id % self.frame_skip == 0:
|
||||||
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
last_detections = self.detector.detect(frame_rgb)
|
||||||
|
|
||||||
|
annotated = self._draw(frame, last_detections)
|
||||||
|
ok, jpeg = cv2.imencode(
|
||||||
|
".jpg",
|
||||||
|
annotated,
|
||||||
|
[cv2.IMWRITE_JPEG_QUALITY, self.jpeg_quality],
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
continue
|
||||||
|
|
||||||
|
fps_frames += 1
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - fps_window_start >= 1:
|
||||||
|
fps = fps_frames / (now - fps_window_start)
|
||||||
|
fps_window_start = now
|
||||||
|
fps_frames = 0
|
||||||
|
else:
|
||||||
|
fps = self.fps
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
current_version = self.reconnect_version
|
||||||
|
if current_version != run_version:
|
||||||
|
continue
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
self.latest_jpeg = jpeg.tobytes()
|
||||||
|
self.latest_detections = list(last_detections)
|
||||||
|
self.updated_at = time.time()
|
||||||
|
self.connected = True
|
||||||
|
self.error = ""
|
||||||
|
self.fps = fps
|
||||||
|
|
||||||
|
if cap is not None:
|
||||||
|
cap.release()
|
||||||
|
self._set_connection_state(False, "已停止")
|
||||||
|
|
||||||
|
def _resize(self, frame: Any) -> Any:
|
||||||
|
if not self.resize_width:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
height, width = frame.shape[:2]
|
||||||
|
if width <= self.resize_width:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
scale = self.resize_width / width
|
||||||
|
return cv2.resize(frame, (self.resize_width, int(height * scale)))
|
||||||
|
|
||||||
|
def _draw(self, frame: Any, detections: list[dict[str, Any]]) -> Any:
|
||||||
|
for detection in detections:
|
||||||
|
x1, y1, x2, y2 = detection["box"]
|
||||||
|
label = detection["label"]
|
||||||
|
score = detection["score"]
|
||||||
|
|
||||||
|
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 220, 80), 2)
|
||||||
|
text = f"{label} {score:.2f}"
|
||||||
|
cv2.rectangle(frame, (x1, max(0, y1 - 26)), (x1 + 150, y1), (0, 220, 80), -1)
|
||||||
|
cv2.putText(
|
||||||
|
frame,
|
||||||
|
text,
|
||||||
|
(x1 + 5, max(18, y1 - 7)),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.55,
|
||||||
|
(0, 0, 0),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def _set_connection_state(self, connected: bool, error: str) -> None:
|
||||||
|
with self.lock:
|
||||||
|
self.connected = connected
|
||||||
|
self.error = error
|
||||||
|
self.updated_at = time.time()
|
||||||
82
app/templates/index.html
Normal file
82
app/templates/index.html
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
|
<title>DETR 动态打标</title>
|
||||||
|
<link rel="stylesheet" href="/static/style.css" />
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<header class="topbar">
|
||||||
|
<div>
|
||||||
|
<h1>DETR 动态打标</h1>
|
||||||
|
<p>使用 Python、OpenCV、PyTorch、Transformers DETR 和 FastAPI,Mac mini m2 上运行。</p>
|
||||||
|
</div>
|
||||||
|
<div class="badge" id="connection">连接中</div>
|
||||||
|
</header>
|
||||||
|
|
||||||
|
<main class="layout">
|
||||||
|
<section class="video-card">
|
||||||
|
<div class="pipeline">
|
||||||
|
<div class="stage active">源节点</div>
|
||||||
|
<div class="arrow">→</div>
|
||||||
|
<div class="stage active">DETR 推理</div>
|
||||||
|
<div class="arrow">→</div>
|
||||||
|
<div class="stage active">OSD 打标</div>
|
||||||
|
<div class="arrow">→</div>
|
||||||
|
<div class="stage active">FastAPI 输出</div>
|
||||||
|
</div>
|
||||||
|
<div class="video-wrap">
|
||||||
|
<img id="video" src="/video" alt="动态打标视频流" />
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<aside class="side-card">
|
||||||
|
<section>
|
||||||
|
<h2>运行状态</h2>
|
||||||
|
<dl class="status-grid">
|
||||||
|
<dt>摄像头</dt>
|
||||||
|
<dd>
|
||||||
|
<select id="device-select" class="device-select"></select>
|
||||||
|
</dd>
|
||||||
|
<dt>视频源</dt>
|
||||||
|
<dd id="source">{{ stream_url }}</dd>
|
||||||
|
<dt>模型</dt>
|
||||||
|
<dd>{{ model }}</dd>
|
||||||
|
<dt>设备</dt>
|
||||||
|
<dd>{{ device }}</dd>
|
||||||
|
<dt>帧号</dt>
|
||||||
|
<dd id="frame-id">-</dd>
|
||||||
|
<dt>FPS</dt>
|
||||||
|
<dd id="fps">-</dd>
|
||||||
|
<dt>状态</dt>
|
||||||
|
<dd id="error">等待视频帧</dd>
|
||||||
|
</dl>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h2>连接耗时</h2>
|
||||||
|
<dl class="status-grid">
|
||||||
|
<dt>Token</dt>
|
||||||
|
<dd id="timing-token">-</dd>
|
||||||
|
<dt>签名</dt>
|
||||||
|
<dd id="timing-sign">-</dd>
|
||||||
|
<dt>取流地址</dt>
|
||||||
|
<dd id="timing-url">-</dd>
|
||||||
|
<dt>打开流</dt>
|
||||||
|
<dd id="timing-open">-</dd>
|
||||||
|
<dt>首帧</dt>
|
||||||
|
<dd id="timing-frame">-</dd>
|
||||||
|
</dl>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h2>检测结果</h2>
|
||||||
|
<div id="detections" class="detections empty">暂无目标</div>
|
||||||
|
</section>
|
||||||
|
</aside>
|
||||||
|
</main>
|
||||||
|
|
||||||
|
<script src="/static/app.js"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
1
devicelist.env.example
Normal file
1
devicelist.env.example
Normal file
@@ -0,0 +1 @@
|
|||||||
|
DEVICE_NUM=33082500001327632958,33102300001327287520
|
||||||
6
info.md
Normal file
6
info.md
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
能力开放接口规范: 能力开放接口规范-20260602.docx
|
||||||
|
|
||||||
|
接口地址:https://apicapacity.51iwifi.com
|
||||||
|
appId: a3196a2af16ddc8f93cc
|
||||||
|
appSecret: 7f213e12590c3c0d
|
||||||
|
accessToken:临时生成,会过期
|
||||||
11
requirements.txt
Normal file
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn[standard]
|
||||||
|
opencv-python
|
||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
transformers
|
||||||
|
timm
|
||||||
|
pillow
|
||||||
|
numpy
|
||||||
|
python-dotenv
|
||||||
|
jinja2
|
||||||
481
tokenizer.md
Normal file
481
tokenizer.md
Normal file
@@ -0,0 +1,481 @@
|
|||||||
|
# DETR 的视觉 token 化过程说明
|
||||||
|
|
||||||
|
本文基于当前项目代码 `app/detector.py` 中的实现说明 DETR 的“token 化”过程。
|
||||||
|
|
||||||
|
当前代码使用的是 Hugging Face Transformers:
|
||||||
|
|
||||||
|
```python
|
||||||
|
self.processor = DetrImageProcessor.from_pretrained(model_name)
|
||||||
|
self.model = DetrForObjectDetection.from_pretrained(model_name)
|
||||||
|
```
|
||||||
|
|
||||||
|
默认模型为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
facebook/detr-resnet-50
|
||||||
|
```
|
||||||
|
|
||||||
|
需要注意:DETR 这里没有文本 tokenizer。它处理的是图像,因此所谓“token 化”指的是把图像经过 CNN backbone 后得到的二维视觉特征图,展开成 Transformer 可以处理的一维视觉 token 序列。
|
||||||
|
|
||||||
|
## 当前代码中的入口
|
||||||
|
|
||||||
|
在 `app/detector.py` 中,检测入口是:
|
||||||
|
|
||||||
|
```python
|
||||||
|
image = Image.fromarray(frame_rgb)
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||||
|
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
```
|
||||||
|
|
||||||
|
这里分成两部分:
|
||||||
|
|
||||||
|
1. `DetrImageProcessor`:做图像预处理。
|
||||||
|
2. `DetrForObjectDetection`:在模型内部完成视觉特征提取、flatten、位置编码、Transformer 编码解码和目标检测。
|
||||||
|
|
||||||
|
## 总体流程
|
||||||
|
|
||||||
|
完整流程可以理解为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
OpenCV RGB 帧
|
||||||
|
↓
|
||||||
|
PIL Image
|
||||||
|
↓
|
||||||
|
DetrImageProcessor 图像预处理
|
||||||
|
↓
|
||||||
|
pixel_values: [batch, 3, H, W]
|
||||||
|
pixel_mask: [batch, H, W]
|
||||||
|
↓
|
||||||
|
ResNet-50 backbone 提取视觉特征
|
||||||
|
↓
|
||||||
|
feature map: [batch, 2048, H', W']
|
||||||
|
↓
|
||||||
|
1×1 convolution 投影通道
|
||||||
|
↓
|
||||||
|
projected feature map: [batch, 256, H', W']
|
||||||
|
↓
|
||||||
|
flatten 空间维度 H' × W'
|
||||||
|
↓
|
||||||
|
visual tokens: [batch, H'×W', 256]
|
||||||
|
↓
|
||||||
|
加入二维位置编码
|
||||||
|
↓
|
||||||
|
Transformer Encoder
|
||||||
|
↓
|
||||||
|
Object Queries + Transformer Decoder
|
||||||
|
↓
|
||||||
|
类别 logits + 边界框 boxes
|
||||||
|
↓
|
||||||
|
post_process_object_detection 还原到原图坐标
|
||||||
|
```
|
||||||
|
|
||||||
|
## 第 1 步:图像预处理
|
||||||
|
|
||||||
|
代码:
|
||||||
|
|
||||||
|
```python
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
`DetrImageProcessor` 主要做这些事情:
|
||||||
|
|
||||||
|
- 调整图像尺寸。
|
||||||
|
- 转换为 PyTorch tensor。
|
||||||
|
- 归一化像素值。
|
||||||
|
- 生成 `pixel_values`。
|
||||||
|
- 必要时生成 `pixel_mask`,用于标记 padding 区域。
|
||||||
|
|
||||||
|
输出通常包含:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"pixel_values": Tensor[batch, 3, H, W],
|
||||||
|
"pixel_mask": Tensor[batch, H, W]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
其中:
|
||||||
|
|
||||||
|
- `pixel_values` 是送入模型的图像张量。
|
||||||
|
- `pixel_mask` 用于告诉模型哪些区域是真实图像,哪些区域是 padding。
|
||||||
|
|
||||||
|
这一步不是文本 token 化,不会产生 `input_ids`、`attention_mask` 这类 NLP tokenizer 输出。
|
||||||
|
|
||||||
|
## 第 2 步:ResNet-50 提取特征图
|
||||||
|
|
||||||
|
模型内部首先使用 ResNet-50 backbone 处理图像:
|
||||||
|
|
||||||
|
```text
|
||||||
|
pixel_values: [batch, 3, H, W]
|
||||||
|
↓ ResNet-50
|
||||||
|
feature map: [batch, 2048, H', W']
|
||||||
|
```
|
||||||
|
|
||||||
|
`H'` 和 `W'` 是下采样后的空间尺寸。ResNet 通常会把图像下采样约 32 倍。
|
||||||
|
|
||||||
|
例如输入图像尺寸为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 3, 800, 1280]
|
||||||
|
```
|
||||||
|
|
||||||
|
经过 ResNet 后,特征图可能接近:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 2048, 25, 40]
|
||||||
|
```
|
||||||
|
|
||||||
|
这里的每个空间位置 `(y, x)` 都是一个高层视觉特征向量。
|
||||||
|
|
||||||
|
## 第 3 步:1×1 卷积投影通道
|
||||||
|
|
||||||
|
ResNet 输出通道数通常是 `2048`,而 DETR Transformer 默认隐藏维度通常是 `256`。
|
||||||
|
|
||||||
|
因此模型会使用一个 `1×1 convolution` 做通道投影:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 2048, H', W']
|
||||||
|
↓ 1×1 conv
|
||||||
|
[batch, 256, H', W']
|
||||||
|
```
|
||||||
|
|
||||||
|
这一步不会改变空间大小,只改变每个空间位置的特征维度。
|
||||||
|
|
||||||
|
可以理解为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
每个网格点的 2048 维向量 → 256 维向量
|
||||||
|
```
|
||||||
|
|
||||||
|
## 第 4 步:flatten 成视觉 token 序列
|
||||||
|
|
||||||
|
这是“视觉 token 化”的核心步骤。
|
||||||
|
|
||||||
|
投影后的特征图形状是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 256, H', W']
|
||||||
|
```
|
||||||
|
|
||||||
|
模型会把二维空间维度 `H' × W'` 展开成一维序列:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 256, H', W']
|
||||||
|
↓ flatten H' 和 W'
|
||||||
|
[batch, 256, H'×W']
|
||||||
|
↓ 调整维度顺序
|
||||||
|
[batch, H'×W', 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
如果特征图是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 256, 25, 40]
|
||||||
|
```
|
||||||
|
|
||||||
|
那么 token 数是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
25 × 40 = 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
最终得到:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 1000, 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
也就是说:
|
||||||
|
|
||||||
|
```text
|
||||||
|
每个特征图网格位置 = 1 个视觉 token
|
||||||
|
每个视觉 token = 1 个 256 维向量
|
||||||
|
```
|
||||||
|
|
||||||
|
伪代码可以写成:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# x: [batch, 256, h, w]
|
||||||
|
x = x.flatten(2) # [batch, 256, h*w]
|
||||||
|
x = x.transpose(1, 2) # [batch, h*w, 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
原始 DETR 论文和部分 PyTorch 实现中也常见 sequence-first 格式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# x: [batch, 256, h, w]
|
||||||
|
x = x.flatten(2) # [batch, 256, h*w]
|
||||||
|
x = x.permute(2, 0, 1) # [h*w, batch, 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
两种写法本质相同,只是 Transformer 接口期望的维度顺序不同。
|
||||||
|
|
||||||
|
## 第 5 步:加入二维位置编码
|
||||||
|
|
||||||
|
Transformer 本身不理解图像中的二维空间位置。
|
||||||
|
|
||||||
|
flatten 后,模型只看到一串 token:
|
||||||
|
|
||||||
|
```text
|
||||||
|
token_0, token_1, token_2, ..., token_N
|
||||||
|
```
|
||||||
|
|
||||||
|
如果不加入位置编码,模型不知道某个 token 原来位于图像左上角、中心还是右下角。
|
||||||
|
|
||||||
|
因此 DETR 会为特征图每个 `(y, x)` 位置生成二维位置编码:
|
||||||
|
|
||||||
|
```text
|
||||||
|
position encoding: [batch, 256, H', W']
|
||||||
|
```
|
||||||
|
|
||||||
|
然后同样 flatten:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 256, H', W']
|
||||||
|
↓
|
||||||
|
[batch, H'×W', 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
Transformer Encoder 接收的是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
visual token + positional encoding
|
||||||
|
```
|
||||||
|
|
||||||
|
位置编码让模型知道 token 的空间布局。
|
||||||
|
|
||||||
|
## 第 6 步:Transformer Encoder 处理视觉 token
|
||||||
|
|
||||||
|
经过 flatten 和位置编码后,视觉 token 序列进入 Transformer Encoder:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, H'×W', 256]
|
||||||
|
↓ Transformer Encoder
|
||||||
|
[batch, H'×W', 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
Encoder 会通过 self-attention 建模图像中不同区域之间的关系。
|
||||||
|
|
||||||
|
例如:
|
||||||
|
|
||||||
|
- 车头区域可以关注车身区域。
|
||||||
|
- 道路区域可以影响车辆判断。
|
||||||
|
- 远处小目标可以和周围上下文一起被理解。
|
||||||
|
|
||||||
|
## 第 7 步:Object Queries 和 Transformer Decoder
|
||||||
|
|
||||||
|
DETR 与传统检测器不同,它不是先生成大量 anchor box。
|
||||||
|
|
||||||
|
它使用一组可学习的 object queries。常见数量是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
100 个 object queries
|
||||||
|
```
|
||||||
|
|
||||||
|
这些 queries 进入 Transformer Decoder,并关注 Encoder 输出的视觉 token:
|
||||||
|
|
||||||
|
```text
|
||||||
|
object queries: [batch, 100, 256]
|
||||||
|
encoder tokens: [batch, H'×W', 256]
|
||||||
|
↓ Transformer Decoder
|
||||||
|
object features: [batch, 100, 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
每个 query 最终预测一个候选目标:
|
||||||
|
|
||||||
|
```text
|
||||||
|
类别 + 边界框
|
||||||
|
```
|
||||||
|
|
||||||
|
因此 DETR 输出通常可以理解为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
最多 100 个候选目标
|
||||||
|
```
|
||||||
|
|
||||||
|
每个候选目标会包含:
|
||||||
|
|
||||||
|
- 类别 logits。
|
||||||
|
- 归一化边界框。
|
||||||
|
|
||||||
|
## 第 8 步:后处理成车辆检测结果
|
||||||
|
|
||||||
|
当前代码中的后处理是:
|
||||||
|
|
||||||
|
```python
|
||||||
|
target_sizes = torch.tensor([image.size[::-1]], device=self.device)
|
||||||
|
results = self.processor.post_process_object_detection(
|
||||||
|
outputs,
|
||||||
|
target_sizes=target_sizes,
|
||||||
|
threshold=self.confidence,
|
||||||
|
)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
这一步会:
|
||||||
|
|
||||||
|
- 把模型输出的归一化框还原为原图坐标。
|
||||||
|
- 根据置信度阈值过滤低分检测。
|
||||||
|
- 返回 `scores`、`labels`、`boxes`。
|
||||||
|
|
||||||
|
然后代码过滤车辆类别:
|
||||||
|
|
||||||
|
```python
|
||||||
|
label_name = self.model.config.id2label[label.item()]
|
||||||
|
if label_name not in self.vehicle_labels:
|
||||||
|
continue
|
||||||
|
```
|
||||||
|
|
||||||
|
默认车辆类别为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
car, motorcycle, bus, truck, bicycle
|
||||||
|
```
|
||||||
|
|
||||||
|
最终输出格式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"label": "car",
|
||||||
|
"score": 0.9132,
|
||||||
|
"box": [x1, y1, x2, y2]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与文本 tokenizer 的区别
|
||||||
|
|
||||||
|
| 对比项 | 文本 tokenizer | DETR 视觉 token 化 |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| 输入 | 文本字符串 | 图像张量 |
|
||||||
|
| 输出 | token id 序列 | 视觉特征向量序列 |
|
||||||
|
| token 来源 | 词、子词、字符片段 | CNN 特征图空间网格 |
|
||||||
|
| token 内容 | 离散整数 id | 连续浮点向量 |
|
||||||
|
| 位置编码 | 一维位置编码 | 二维图像位置编码 |
|
||||||
|
| 当前代码中的类 | 无 | `DetrImageProcessor` + `DetrForObjectDetection` |
|
||||||
|
|
||||||
|
## 为什么说它是 patch-like features
|
||||||
|
|
||||||
|
它们可以叫做 `patch-like features`,但不是 ViT 那种直接切原图 patch。
|
||||||
|
|
||||||
|
ViT 通常是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
原图 -> 切成 16×16 patch -> 线性投影 -> token
|
||||||
|
```
|
||||||
|
|
||||||
|
DETR-ResNet 是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
原图 -> ResNet 特征图 -> 每个特征图网格点 -> token
|
||||||
|
```
|
||||||
|
|
||||||
|
因此 DETR 的每个 token:
|
||||||
|
|
||||||
|
- 对应特征图上的一个空间位置。
|
||||||
|
- 感受野来自 ResNet 深层网络。
|
||||||
|
- 通常覆盖原图中一片较大的区域。
|
||||||
|
- 相邻 token 的感受野会重叠。
|
||||||
|
|
||||||
|
## 一个具体尺寸例子
|
||||||
|
|
||||||
|
假设预处理后图像大小为:
|
||||||
|
|
||||||
|
```text
|
||||||
|
800 × 1280
|
||||||
|
```
|
||||||
|
|
||||||
|
ResNet-50 下采样约 32 倍:
|
||||||
|
|
||||||
|
```text
|
||||||
|
H' = 800 / 32 ≈ 25
|
||||||
|
W' = 1280 / 32 ≈ 40
|
||||||
|
```
|
||||||
|
|
||||||
|
Backbone 输出:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 2048, 25, 40]
|
||||||
|
```
|
||||||
|
|
||||||
|
1×1 卷积投影:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 256, 25, 40]
|
||||||
|
```
|
||||||
|
|
||||||
|
flatten:
|
||||||
|
|
||||||
|
```text
|
||||||
|
[batch, 1000, 256]
|
||||||
|
```
|
||||||
|
|
||||||
|
所以该图像大约会生成:
|
||||||
|
|
||||||
|
```text
|
||||||
|
1000 个视觉 token
|
||||||
|
```
|
||||||
|
|
||||||
|
每个 token 是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
256 维连续向量
|
||||||
|
```
|
||||||
|
|
||||||
|
## 当前项目中的实际检测路径
|
||||||
|
|
||||||
|
当前项目的整体路径是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
RTSP/HLS 视频帧
|
||||||
|
↓
|
||||||
|
OpenCV 读取 BGR frame
|
||||||
|
↓
|
||||||
|
cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
↓
|
||||||
|
Image.fromarray(frame_rgb)
|
||||||
|
↓
|
||||||
|
DetrImageProcessor 图像预处理
|
||||||
|
↓
|
||||||
|
DetrForObjectDetection 内部完成视觉 token 化和目标检测
|
||||||
|
↓
|
||||||
|
post_process_object_detection
|
||||||
|
↓
|
||||||
|
过滤车辆类别
|
||||||
|
↓
|
||||||
|
OpenCV OSD 画框
|
||||||
|
↓
|
||||||
|
FastAPI /video 输出动态打标画面
|
||||||
|
```
|
||||||
|
|
||||||
|
对应代码位置:
|
||||||
|
|
||||||
|
- `app/detector.py`:DETR 图像预处理、模型推理、后处理。
|
||||||
|
- `app/stream_worker.py`:视频帧读取、推理调用、OSD 画框。
|
||||||
|
- `app/main.py`:MJPEG 视频流和检测结果接口。
|
||||||
|
|
||||||
|
## 小结
|
||||||
|
|
||||||
|
当前项目中的 DETR 不包含文本 tokenizer。
|
||||||
|
|
||||||
|
它的视觉 token 化具体是:
|
||||||
|
|
||||||
|
```text
|
||||||
|
图像经过 ResNet-50 得到二维特征图
|
||||||
|
↓
|
||||||
|
1×1 卷积把通道投影到 Transformer hidden size
|
||||||
|
↓
|
||||||
|
把 H' × W' 个空间位置 flatten 成 H'×W' 个视觉 token
|
||||||
|
↓
|
||||||
|
给每个 token 加二维位置编码
|
||||||
|
↓
|
||||||
|
送入 Transformer Encoder
|
||||||
|
```
|
||||||
|
|
||||||
|
因此,“token 数”主要由输入分辨率和 backbone 下采样比例决定:
|
||||||
|
|
||||||
|
```text
|
||||||
|
视觉 token 数 ≈ (输入高度 / 32) × (输入宽度 / 32)
|
||||||
|
```
|
||||||
|
|
||||||
|
实际数量会受到图像预处理 resize、padding 和特征图尺寸取整影响。
|
||||||
BIN
能力开放接口规范-20260602.docx
Normal file
BIN
能力开放接口规范-20260602.docx
Normal file
Binary file not shown.
Reference in New Issue
Block a user