first commit
This commit is contained in:
13
.env.example
Normal file
13
.env.example
Normal file
@@ -0,0 +1,13 @@
|
||||
API_BASE_URL=https://apicapacity.51iwifi.com
|
||||
APP_ID=your_app_id
|
||||
APP_SECRET=your_app_secret
|
||||
DEVICE_LIST_PATH=devicelist.env
|
||||
DEVICE_ACCOUNT=21cn
|
||||
STREAM_METHOD=capacity.geye.device.devUrl.get
|
||||
STREAM_URL=
|
||||
DETR_CONFIDENCE=0.6
|
||||
FRAME_SKIP=5
|
||||
DETR_MODEL=facebook/detr-resnet-50
|
||||
JPEG_QUALITY=80
|
||||
RESIZE_WIDTH=960
|
||||
VEHICLE_LABELS=car,motorcycle,bus,truck,bicycle
|
||||
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
.venv/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
.env
|
||||
devicelist.env
|
||||
.DS_Store
|
||||
137
README.md
Normal file
137
README.md
Normal file
@@ -0,0 +1,137 @@
|
||||
# macOS DETR 车辆动态打标
|
||||
|
||||
这是一个参考 `../VideoPipe` 思路实现的 macOS 友好版 Python 项目:
|
||||
|
||||
```text
|
||||
RTSP/HLS/本地文件源节点 -> DETR 车辆检测推理节点 -> OSD 画框节点 -> FastAPI 远程输出节点
|
||||
```
|
||||
|
||||
本项目不直接移植 VideoPipe 的 C++、TensorRT、CUDA、GStreamer RTSP Server 工程,而是保留它的“源节点 → 推理节点 → OSD → 输出节点”思想,使用 Python、OpenCV、PyTorch、Transformers DETR 和 FastAPI,方便在 Mac mini / Apple Silicon 上运行。
|
||||
|
||||
## 架构和推理流程
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
A[devicelist.env 摄像头列表] --> B[DeviceManager 默认选择第一个设备]
|
||||
B --> C[能力开放接口获取 RTSP 地址]
|
||||
C --> D[OpenCV VideoCapture 读取视频流]
|
||||
D --> E[StreamWorker 抽帧]
|
||||
E --> F[DetrVehicleDetector 车辆检测]
|
||||
F --> G[OpenCV OSD 画框]
|
||||
G --> H[FastAPI MJPEG /video]
|
||||
G --> I[WebSocket /ws/detections]
|
||||
H --> J[浏览器动态打标画面]
|
||||
I --> J
|
||||
```
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[读取一帧 BGR 图像] --> B{是否达到 FRAME_SKIP}
|
||||
B -- 否 --> F[复用上一轮检测结果]
|
||||
B -- 是 --> C[转换为 RGB]
|
||||
C --> D[Transformers ImageProcessor 预处理]
|
||||
D --> E[PyTorch DETR 推理]
|
||||
E --> G[按 DETR_CONFIDENCE 过滤车辆类别]
|
||||
F --> H[绘制检测框和标签]
|
||||
G --> H
|
||||
H --> I[JPEG 编码]
|
||||
I --> J[FastAPI 输出到远程浏览器]
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
视频流地址通过环境变量传入,不要把 RTSP/HLS token 写进代码或文档。
|
||||
|
||||
```bash
|
||||
export STREAM_URL='rtsp://你的摄像头或视频流地址'
|
||||
export DETR_CONFIDENCE=0.6
|
||||
export FRAME_SKIP=5
|
||||
```
|
||||
|
||||
可选配置:
|
||||
|
||||
```bash
|
||||
export DETR_MODEL='facebook/detr-resnet-50'
|
||||
export JPEG_QUALITY=80
|
||||
export RESIZE_WIDTH=960
|
||||
export VEHICLE_LABELS='car,motorcycle,bus,truck,bicycle'
|
||||
```
|
||||
|
||||
配置说明:
|
||||
|
||||
- `STREAM_URL`:RTSP、HLS 或本地视频文件路径,必填。
|
||||
- `DETR_CONFIDENCE`:检测置信度阈值,默认 `0.6`。
|
||||
- `FRAME_SKIP`:每隔多少帧做一次 DETR 推理,默认 `3`。数值越大,负载越低。
|
||||
- `RESIZE_WIDTH`:可选,设置后会按宽度等比缩小视频帧,降低 Mac mini 压力。
|
||||
- `VEHICLE_LABELS`:需要保留的 COCO 车辆类别。
|
||||
|
||||
## 运行
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
本机访问:
|
||||
|
||||
```text
|
||||
http://127.0.0.1:8000
|
||||
```
|
||||
|
||||
局域网远端访问:
|
||||
|
||||
```text
|
||||
http://<Mac-IP>:8000
|
||||
```
|
||||
|
||||
## 接口
|
||||
|
||||
- `GET /`:中文浏览器看板。
|
||||
- `GET /video`:MJPEG 动态打标视频流。
|
||||
- `GET /detections`:最近一帧检测结果 JSON。
|
||||
- `GET /status`:运行状态、模型、设备、帧号、FPS。
|
||||
- `WS /ws/detections`:实时推送检测元数据。
|
||||
|
||||
## RTSP 和 HLS 选择
|
||||
|
||||
建议优先使用 RTSP,实时性更好。HLS `.m3u8` 通常更稳定,但会有几秒到十几秒延迟。
|
||||
|
||||
OpenCV 需要带 FFmpeg 支持才能打开很多 RTSP/HLS 地址。如果无法打开远端流,可以先用本地视频文件验证项目,再考虑增加 PyAV 或 ffmpeg-python 输入后端。
|
||||
|
||||
## macOS 远端访问检查
|
||||
|
||||
如果其他机器访问不到页面,检查 macOS 防火墙,并确认服务监听在 `0.0.0.0`。
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8000/status
|
||||
curl http://<Mac-IP>:8000/status
|
||||
lsof -i :8000
|
||||
```
|
||||
|
||||
## 性能建议
|
||||
|
||||
DETR 原版模型较重。Mac mini 上建议先这样跑:
|
||||
|
||||
```bash
|
||||
export FRAME_SKIP=5
|
||||
export RESIZE_WIDTH=960
|
||||
```
|
||||
|
||||
Apple Silicon 上会优先使用 PyTorch MPS;不可用时自动回退到 CPU。
|
||||
|
||||
## 与 VideoPipe 的对应关系
|
||||
|
||||
| VideoPipe 概念 | 本项目实现 |
|
||||
| --- | --- |
|
||||
| 源节点 | `StreamWorker` 中的 OpenCV `VideoCapture` |
|
||||
| 推理节点 | `DetrVehicleDetector` |
|
||||
| OSD 节点 | `StreamWorker._draw()` |
|
||||
| 输出节点 | FastAPI `/video`、`/detections`、`/ws/detections` |
|
||||
| 分析看板 | 浏览器中文 Dashboard |
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
123
app/capacity_api.py
Normal file
123
app/capacity_api.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
access_token: str
|
||||
expires_at: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamUrlResult:
|
||||
url: str
|
||||
timings: dict[str, float]
|
||||
|
||||
|
||||
class CapacityApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
app_id: str,
|
||||
app_secret: str,
|
||||
account: str,
|
||||
method: str,
|
||||
timeout: int = 20,
|
||||
):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.app_id = app_id
|
||||
self.app_secret = app_secret
|
||||
self.account = account
|
||||
self.method = method
|
||||
self.timeout = timeout
|
||||
self.token: Token | None = None
|
||||
|
||||
def get_stream_url(self, device_num: str) -> str:
|
||||
return self.get_stream_url_details(device_num).url
|
||||
|
||||
def get_stream_url_details(self, device_num: str) -> StreamUrlResult:
|
||||
timings: dict[str, float] = {}
|
||||
started = time.monotonic()
|
||||
access_token = self._get_access_token(timings)
|
||||
timings["token_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||
started = time.monotonic()
|
||||
business_params = {
|
||||
"account": self.account,
|
||||
"deviceNum": device_num,
|
||||
"isSubStream": 0,
|
||||
"networkType": 1,
|
||||
"urlType": 1,
|
||||
}
|
||||
# 接口文档要求业务参数整体放进 params JSON 字符串后再参与签名。
|
||||
params = {
|
||||
"accessToken": access_token,
|
||||
"appId": self.app_id,
|
||||
"method": self.method,
|
||||
"params": json.dumps(business_params, ensure_ascii=False, separators=(",", ":")),
|
||||
"timestamp": self._timestamp(),
|
||||
"v": "1.0.0",
|
||||
}
|
||||
params["sign"] = self._sign(params)
|
||||
timings["sign_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||
started = time.monotonic()
|
||||
data = self._get_json(f"{self.base_url}/rest", params)
|
||||
timings["stream_url_ms"] = round((time.monotonic() - started) * 1000, 2)
|
||||
if data.get("errorCode") != "0":
|
||||
raise RuntimeError(data.get("errorMsg") or f"播放地址接口返回错误 {data.get('errorCode')}")
|
||||
|
||||
payload = data.get("data") or {}
|
||||
stream_url = payload.get("rtspUrl") or payload.get("rtspUri")
|
||||
if not stream_url:
|
||||
raise RuntimeError("播放地址接口未返回 RTSP 地址")
|
||||
return StreamUrlResult(stream_url, timings)
|
||||
|
||||
def _get_access_token(self, timings: dict[str, float] | None = None) -> str:
|
||||
if self.token and time.time() < self.token.expires_at - 300:
|
||||
if timings is not None:
|
||||
timings["token_cache"] = 1
|
||||
return self.token.access_token
|
||||
|
||||
data = self._get_json(
|
||||
f"{self.base_url}/oauth/token",
|
||||
{
|
||||
"grantType": "client_credential",
|
||||
"appId": self.app_id,
|
||||
"appSecret": self.app_secret,
|
||||
},
|
||||
)
|
||||
if data.get("errorCode") != "0":
|
||||
raise RuntimeError(data.get("errorMsg") or f"获取 accessToken 失败 {data.get('errorCode')}")
|
||||
|
||||
payload = data.get("data") or {}
|
||||
access_token = payload.get("accessToken")
|
||||
if not access_token:
|
||||
raise RuntimeError("token 接口未返回 accessToken")
|
||||
|
||||
self.token = Token(
|
||||
access_token=access_token,
|
||||
expires_at=time.time() + int(payload.get("expiresIn") or 604800),
|
||||
)
|
||||
return access_token
|
||||
|
||||
def _get_json(self, url: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
query = urllib.parse.urlencode(params)
|
||||
with urllib.request.urlopen(f"{url}?{query}", timeout=self.timeout) as response:
|
||||
body = response.read().decode("utf-8")
|
||||
return json.loads(body)
|
||||
|
||||
def _sign(self, params: dict[str, Any]) -> str:
|
||||
# 签名规则:appSecret + 按 ASCII key 排序后的 key/value + appSecret。
|
||||
raw = self.app_secret + "".join(f"{key}{params[key]}" for key in sorted(params)) + self.app_secret
|
||||
return hashlib.md5(raw.encode("utf-8")).hexdigest().upper()
|
||||
|
||||
@staticmethod
|
||||
def _timestamp() -> str:
|
||||
return datetime.now(timezone(timedelta(hours=8))).strftime("%Y-%m-%d %H:%M:%S")
|
||||
66
app/config.py
Normal file
66
app/config.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
stream_url: str
|
||||
api_base_url: str
|
||||
app_id: str
|
||||
app_secret: str
|
||||
device_list_path: str
|
||||
device_account: str
|
||||
stream_method: str
|
||||
detr_model: str
|
||||
confidence: float
|
||||
frame_skip: int
|
||||
jpeg_quality: int
|
||||
resize_width: int | None
|
||||
vehicle_labels: set[str]
|
||||
|
||||
|
||||
def _optional_int(name: str) -> int | None:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
return None
|
||||
parsed = int(value)
|
||||
return parsed if parsed > 0 else None
|
||||
|
||||
|
||||
def load_settings() -> Settings:
|
||||
stream_url = os.getenv("STREAM_URL", "").strip()
|
||||
|
||||
vehicle_labels = {
|
||||
item.strip()
|
||||
for item in os.getenv("VEHICLE_LABELS", "car,motorcycle,bus,truck,bicycle").split(",")
|
||||
if item.strip()
|
||||
}
|
||||
|
||||
return Settings(
|
||||
stream_url=stream_url,
|
||||
api_base_url=os.getenv("API_BASE_URL", "https://apicapacity.51iwifi.com"),
|
||||
app_id=os.getenv("APP_ID", ""),
|
||||
app_secret=os.getenv("APP_SECRET", ""),
|
||||
device_list_path=os.getenv("DEVICE_LIST_PATH", "devicelist.env"),
|
||||
device_account=os.getenv("DEVICE_ACCOUNT", "21cn"),
|
||||
stream_method=os.getenv("STREAM_METHOD", "capacity.geye.device.devUrl.get"),
|
||||
detr_model=os.getenv("DETR_MODEL", "facebook/detr-resnet-50"),
|
||||
confidence=float(os.getenv("DETR_CONFIDENCE", "0.6")),
|
||||
frame_skip=max(1, int(os.getenv("FRAME_SKIP", "3"))),
|
||||
jpeg_quality=min(100, max(1, int(os.getenv("JPEG_QUALITY", "80")))),
|
||||
resize_width=_optional_int("RESIZE_WIDTH"),
|
||||
vehicle_labels=vehicle_labels,
|
||||
)
|
||||
|
||||
|
||||
def mask_url(url: str) -> str:
|
||||
if "token=" not in url:
|
||||
return url
|
||||
prefix, _ = url.split("token=", 1)
|
||||
return f"{prefix}token=***"
|
||||
60
app/detector.py
Normal file
60
app/detector.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import DetrForObjectDetection, DetrImageProcessor
|
||||
|
||||
|
||||
class DetrVehicleDetector:
|
||||
def __init__(self, model_name: str, confidence: float, vehicle_labels: set[str]):
|
||||
self.confidence = confidence
|
||||
self.vehicle_labels = vehicle_labels
|
||||
self.device = self._select_device()
|
||||
self.processor = DetrImageProcessor.from_pretrained(model_name)
|
||||
self.model = DetrForObjectDetection.from_pretrained(model_name)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
@staticmethod
|
||||
def _select_device() -> torch.device:
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
@property
|
||||
def device_name(self) -> str:
|
||||
return self.device.type
|
||||
|
||||
@torch.no_grad()
|
||||
def detect(self, frame_rgb: Any) -> list[dict[str, Any]]:
|
||||
image = Image.fromarray(frame_rgb)
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
# DETR 后处理需要原图尺寸,PIL size 是 (宽, 高),这里转成 (高, 宽)。
|
||||
target_sizes = torch.tensor([image.size[::-1]], device=self.device)
|
||||
results = self.processor.post_process_object_detection(
|
||||
outputs,
|
||||
target_sizes=target_sizes,
|
||||
threshold=self.confidence,
|
||||
)[0]
|
||||
|
||||
detections = []
|
||||
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
||||
label_name = self.model.config.id2label[label.item()]
|
||||
if label_name not in self.vehicle_labels:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = box.detach().cpu().numpy().astype(int).tolist()
|
||||
detections.append(
|
||||
{
|
||||
"label": label_name,
|
||||
"score": round(float(score.detach().cpu()), 4),
|
||||
"box": [x1, y1, x2, y2],
|
||||
}
|
||||
)
|
||||
|
||||
return detections
|
||||
91
app/device_manager.py
Normal file
91
app/device_manager.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.capacity_api import CapacityApiClient
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Device:
|
||||
name: str
|
||||
device_num: str
|
||||
|
||||
|
||||
class DeviceManager:
|
||||
def __init__(self, path: str, api_client: CapacityApiClient, fallback_url: str = ""):
|
||||
self.devices = self._load_devices(path)
|
||||
self.api_client = api_client
|
||||
self.fallback_url = fallback_url
|
||||
self.lock = threading.Lock()
|
||||
self.current_device_num = self.devices[0].device_num if self.devices else ""
|
||||
self.current_url = fallback_url
|
||||
self.timings: dict[str, float] = {}
|
||||
self.updated_at = 0.0
|
||||
self.version = 0
|
||||
|
||||
def set_current_device(self, device_num: str) -> int:
|
||||
if device_num not in {device.device_num for device in self.devices}:
|
||||
raise ValueError("设备不在 devicelist.env 中")
|
||||
with self.lock:
|
||||
self.current_device_num = device_num
|
||||
self.current_url = ""
|
||||
self.timings = {}
|
||||
self.updated_at = time.time()
|
||||
self.version += 1
|
||||
return self.version
|
||||
|
||||
def resolve_stream_url(self) -> str:
|
||||
with self.lock:
|
||||
device_num = self.current_device_num
|
||||
version = self.version
|
||||
if not device_num:
|
||||
if self.fallback_url:
|
||||
return self.fallback_url
|
||||
raise RuntimeError("devicelist.env 中没有可用设备号")
|
||||
|
||||
result = self.api_client.get_stream_url_details(device_num)
|
||||
with self.lock:
|
||||
# 避免旧摄像头的慢接口响应覆盖用户刚切换的新选择。
|
||||
if version != self.version or device_num != self.current_device_num:
|
||||
return self.current_url
|
||||
self.current_url = result.url
|
||||
self.timings = dict(result.timings)
|
||||
self.updated_at = time.time()
|
||||
return result.url
|
||||
|
||||
def get_snapshot(self) -> dict[str, Any]:
|
||||
with self.lock:
|
||||
return {
|
||||
"devices": [device.__dict__ for device in self.devices],
|
||||
"current_device_num": self.current_device_num,
|
||||
"current_url": self.current_url,
|
||||
"source_timings": dict(self.timings),
|
||||
"source_updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _load_devices(path: str) -> list[Device]:
|
||||
devices: list[Device] = []
|
||||
file_path = Path(path)
|
||||
if not file_path.exists():
|
||||
return devices
|
||||
|
||||
for line in file_path.read_text(encoding="utf-8").splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
if "=" in stripped:
|
||||
name, value = stripped.split("=", 1)
|
||||
values = [item.strip() for item in value.split(",") if item.strip()]
|
||||
display_name = name.strip()
|
||||
else:
|
||||
values = [stripped]
|
||||
display_name = "摄像头"
|
||||
for device_num in values:
|
||||
name = display_name if len(values) == 1 else f"摄像头 {len(devices) + 1}"
|
||||
devices.append(Device(name, device_num))
|
||||
return devices
|
||||
163
app/main.py
Normal file
163
app/main.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from app.capacity_api import CapacityApiClient
|
||||
from app.config import load_settings, mask_url
|
||||
from app.detector import DetrVehicleDetector
|
||||
from app.device_manager import DeviceManager
|
||||
from app.stream_worker import StreamWorker
|
||||
|
||||
settings = load_settings()
|
||||
api_client = CapacityApiClient(
|
||||
base_url=settings.api_base_url,
|
||||
app_id=settings.app_id,
|
||||
app_secret=settings.app_secret,
|
||||
account=settings.device_account,
|
||||
method=settings.stream_method,
|
||||
)
|
||||
detector = DetrVehicleDetector(
|
||||
model_name=settings.detr_model,
|
||||
confidence=settings.confidence,
|
||||
vehicle_labels=settings.vehicle_labels,
|
||||
)
|
||||
device_manager = DeviceManager(
|
||||
path=settings.device_list_path,
|
||||
api_client=api_client,
|
||||
fallback_url=settings.stream_url,
|
||||
)
|
||||
def resolve_stream_url() -> str:
|
||||
return device_manager.resolve_stream_url()
|
||||
|
||||
|
||||
worker = StreamWorker(
|
||||
stream_url=resolve_stream_url,
|
||||
detector=detector,
|
||||
frame_skip=settings.frame_skip,
|
||||
jpeg_quality=settings.jpeg_quality,
|
||||
resize_width=settings.resize_width,
|
||||
)
|
||||
|
||||
app = FastAPI(title="DETR 动态打标")
|
||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||
templates = Jinja2Templates(directory="app/templates")
|
||||
|
||||
|
||||
def display_model_name(model_name: str) -> str:
|
||||
return model_name.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
worker.start()
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown() -> None:
|
||||
worker.stop()
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def index(request: Request) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
"index.html",
|
||||
{
|
||||
"request": request,
|
||||
"model": display_model_name(settings.detr_model),
|
||||
"device": detector.device_name,
|
||||
"stream_url": f"设备号:{device_manager.get_snapshot()['current_device_num']}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/video")
|
||||
def video() -> StreamingResponse:
|
||||
async def generate():
|
||||
while True:
|
||||
frame = worker.get_jpeg()
|
||||
if frame is None:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame + b"\r\n"
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="multipart/x-mixed-replace; boundary=frame",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/detections")
|
||||
def detections() -> JSONResponse:
|
||||
snapshot = worker.get_snapshot()
|
||||
return JSONResponse(
|
||||
{
|
||||
"frame_id": snapshot["frame_id"],
|
||||
"updated_at": snapshot["updated_at"],
|
||||
"detections": snapshot["detections"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.get("/status")
|
||||
def status() -> JSONResponse:
|
||||
snapshot = worker.get_snapshot()
|
||||
device_snapshot = device_manager.get_snapshot()
|
||||
timings = dict(device_snapshot["source_timings"])
|
||||
# 合并取流地址和 OpenCV 读流耗时,前端按同一个 timings 对象展示。
|
||||
timings.update(snapshot["timings"])
|
||||
return JSONResponse(
|
||||
{
|
||||
"running": snapshot["running"],
|
||||
"connected": snapshot["connected"],
|
||||
"frame_id": snapshot["frame_id"],
|
||||
"updated_at": snapshot["updated_at"],
|
||||
"fps": snapshot["fps"],
|
||||
"error": snapshot["error"],
|
||||
"source": mask_url(device_snapshot["current_url"]) if device_snapshot["current_url"] else "等待获取播放地址",
|
||||
"model": display_model_name(settings.detr_model),
|
||||
"device": detector.device_name,
|
||||
"frame_skip": settings.frame_skip,
|
||||
"confidence": settings.confidence,
|
||||
"devices": device_snapshot["devices"],
|
||||
"current_device_num": device_snapshot["current_device_num"],
|
||||
"timings": timings,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.post("/devices/{device_num}")
|
||||
def switch_device(device_num: str) -> JSONResponse:
|
||||
try:
|
||||
device_manager.set_current_device(device_num)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
worker.reconnect()
|
||||
return JSONResponse({"current_device_num": device_num})
|
||||
|
||||
|
||||
@app.websocket("/ws/detections")
|
||||
async def websocket_detections(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = worker.get_snapshot()
|
||||
device_snapshot = device_manager.get_snapshot()
|
||||
data.update(
|
||||
{
|
||||
"devices": device_snapshot["devices"],
|
||||
"current_device_num": device_snapshot["current_device_num"],
|
||||
"source": mask_url(device_snapshot["current_url"]) if device_snapshot["current_url"] else "等待获取播放地址",
|
||||
"timings": {**device_snapshot["source_timings"], **data["timings"]},
|
||||
}
|
||||
)
|
||||
await websocket.send_json(data)
|
||||
await asyncio.sleep(0.5)
|
||||
except WebSocketDisconnect:
|
||||
return
|
||||
142
app/static/app.js
Normal file
142
app/static/app.js
Normal file
@@ -0,0 +1,142 @@
|
||||
const connection = document.querySelector("#connection");
|
||||
const detectionsEl = document.querySelector("#detections");
|
||||
const frameIdEl = document.querySelector("#frame-id");
|
||||
const fpsEl = document.querySelector("#fps");
|
||||
const errorEl = document.querySelector("#error");
|
||||
const sourceEl = document.querySelector("#source");
|
||||
const deviceSelect = document.querySelector("#device-select");
|
||||
const timingTokenEl = document.querySelector("#timing-token");
|
||||
const timingSignEl = document.querySelector("#timing-sign");
|
||||
const timingUrlEl = document.querySelector("#timing-url");
|
||||
const timingOpenEl = document.querySelector("#timing-open");
|
||||
const timingFrameEl = document.querySelector("#timing-frame");
|
||||
|
||||
let selectedDevice = "";
|
||||
let pendingDevice = "";
|
||||
let devicesSignature = "";
|
||||
|
||||
function setConnection(online, text) {
|
||||
connection.textContent = text;
|
||||
connection.classList.toggle("online", online);
|
||||
connection.classList.toggle("offline", !online);
|
||||
}
|
||||
|
||||
function formatMs(value) {
|
||||
if (value === undefined || value === null || value === 0) {
|
||||
return "-";
|
||||
}
|
||||
return `${Number(value).toFixed(2)} ms`;
|
||||
}
|
||||
|
||||
function renderDevices(devices, currentDeviceNum) {
|
||||
if (!devices.length) {
|
||||
deviceSelect.innerHTML = '<option value="">未配置摄像头</option>';
|
||||
deviceSelect.disabled = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const displayDevice = pendingDevice || currentDeviceNum;
|
||||
const nextSignature = `${displayDevice}|${devices.map((device) => device.device_num).join(",")}`;
|
||||
if (nextSignature === devicesSignature) {
|
||||
return;
|
||||
}
|
||||
|
||||
selectedDevice = displayDevice;
|
||||
devicesSignature = nextSignature;
|
||||
deviceSelect.innerHTML = devices
|
||||
.map((device) => {
|
||||
const selected = device.device_num === displayDevice ? "selected" : "";
|
||||
return `<option value="${device.device_num}" ${selected}>${device.name} · ${device.device_num}</option>`;
|
||||
})
|
||||
.join("");
|
||||
}
|
||||
|
||||
function renderTimings(timings) {
|
||||
timingTokenEl.textContent = timings?.token_cache ? "缓存" : formatMs(timings?.token_ms);
|
||||
timingSignEl.textContent = formatMs(timings?.sign_ms);
|
||||
timingUrlEl.textContent = formatMs(timings?.stream_url_ms);
|
||||
timingOpenEl.textContent = formatMs(timings?.open_ms);
|
||||
timingFrameEl.textContent = formatMs(timings?.first_frame_ms);
|
||||
}
|
||||
|
||||
function renderDetections(detections) {
|
||||
if (!detections.length) {
|
||||
detectionsEl.className = "detections empty";
|
||||
detectionsEl.textContent = "暂无目标";
|
||||
return;
|
||||
}
|
||||
|
||||
detectionsEl.className = "detections";
|
||||
detectionsEl.innerHTML = detections
|
||||
.map((det) => {
|
||||
const score = `${(det.score * 100).toFixed(1)}%`;
|
||||
const box = det.box.join(", ");
|
||||
return `
|
||||
<div class="det-item">
|
||||
<div class="det-title">
|
||||
<span>${det.label}</span>
|
||||
<span>${score}</span>
|
||||
</div>
|
||||
<div class="det-box">box: [${box}]</div>
|
||||
</div>
|
||||
`;
|
||||
})
|
||||
.join("");
|
||||
}
|
||||
|
||||
async function switchDevice(deviceNum) {
|
||||
pendingDevice = deviceNum;
|
||||
devicesSignature = "";
|
||||
setConnection(false, "切换中");
|
||||
const response = await fetch(`/devices/${encodeURIComponent(deviceNum)}`, { method: "POST" });
|
||||
if (!response.ok) {
|
||||
throw new Error("切换摄像头失败");
|
||||
}
|
||||
const video = document.querySelector("#video");
|
||||
video.src = `/video?t=${Date.now()}`;
|
||||
}
|
||||
|
||||
function connectWebSocket() {
|
||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
const ws = new WebSocket(`${protocol}://${window.location.host}/ws/detections`);
|
||||
|
||||
ws.addEventListener("open", () => setConnection(true, "已连接"));
|
||||
|
||||
ws.addEventListener("message", (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
frameIdEl.textContent = data.frame_id ?? "-";
|
||||
fpsEl.textContent = data.fps ?? "-";
|
||||
errorEl.textContent = data.error || (data.connected ? "正常" : "未连接");
|
||||
sourceEl.textContent = data.source || "-";
|
||||
setConnection(Boolean(data.connected), data.connected ? "已连接" : "重连中");
|
||||
if (pendingDevice && data.current_device_num === pendingDevice) {
|
||||
pendingDevice = "";
|
||||
deviceSelect.disabled = false;
|
||||
}
|
||||
renderDevices(data.devices || [], data.current_device_num || "");
|
||||
renderTimings(data.timings || {});
|
||||
renderDetections(data.detections || []);
|
||||
});
|
||||
|
||||
ws.addEventListener("close", () => {
|
||||
setConnection(false, "已断开");
|
||||
setTimeout(connectWebSocket, 1500);
|
||||
});
|
||||
|
||||
ws.addEventListener("error", () => {
|
||||
setConnection(false, "连接错误");
|
||||
ws.close();
|
||||
});
|
||||
}
|
||||
|
||||
deviceSelect.addEventListener("change", (event) => {
|
||||
selectedDevice = event.target.value;
|
||||
switchDevice(event.target.value).catch(() => {
|
||||
pendingDevice = "";
|
||||
devicesSignature = "";
|
||||
setConnection(false, "切换失败");
|
||||
deviceSelect.disabled = false;
|
||||
});
|
||||
});
|
||||
|
||||
connectWebSocket();
|
||||
209
app/static/style.css
Normal file
209
app/static/style.css
Normal file
@@ -0,0 +1,209 @@
|
||||
:root {
|
||||
color-scheme: dark;
|
||||
--bg: #0c1017;
|
||||
--panel: #151b26;
|
||||
--panel-2: #101722;
|
||||
--text: #eef4ff;
|
||||
--muted: #8f9db3;
|
||||
--line: #273246;
|
||||
--green: #2ee887;
|
||||
--yellow: #f7c948;
|
||||
--red: #ff6b6b;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
min-height: 100vh;
|
||||
background: radial-gradient(circle at top left, #182235, var(--bg) 45%);
|
||||
color: var(--text);
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||||
}
|
||||
|
||||
.topbar {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 24px;
|
||||
padding: 22px 28px;
|
||||
border-bottom: 1px solid var(--line);
|
||||
background: rgba(12, 16, 23, 0.82);
|
||||
backdrop-filter: blur(14px);
|
||||
}
|
||||
|
||||
h1,
|
||||
h2,
|
||||
p {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 24px;
|
||||
letter-spacing: 0.02em;
|
||||
}
|
||||
|
||||
h2 {
|
||||
margin-bottom: 14px;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
p {
|
||||
margin-top: 8px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.badge {
|
||||
min-width: 86px;
|
||||
padding: 8px 12px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 999px;
|
||||
color: var(--yellow);
|
||||
text-align: center;
|
||||
background: var(--panel);
|
||||
}
|
||||
|
||||
.badge.online {
|
||||
color: var(--green);
|
||||
}
|
||||
|
||||
.badge.offline {
|
||||
color: var(--red);
|
||||
}
|
||||
|
||||
.layout {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(0, 1fr) 360px;
|
||||
gap: 18px;
|
||||
padding: 18px;
|
||||
}
|
||||
|
||||
.video-card,
|
||||
.side-card {
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 18px;
|
||||
background: rgba(21, 27, 38, 0.9);
|
||||
box-shadow: 0 18px 40px rgba(0, 0, 0, 0.28);
|
||||
}
|
||||
|
||||
.video-card {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.pipeline {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
padding: 14px;
|
||||
border-bottom: 1px solid var(--line);
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.stage {
|
||||
flex: 0 0 auto;
|
||||
padding: 9px 12px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 10px;
|
||||
color: var(--muted);
|
||||
background: var(--panel-2);
|
||||
}
|
||||
|
||||
.stage.active {
|
||||
border-color: rgba(46, 232, 135, 0.5);
|
||||
color: var(--green);
|
||||
}
|
||||
|
||||
.arrow {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.video-wrap {
|
||||
display: grid;
|
||||
place-items: center;
|
||||
min-height: 420px;
|
||||
background: #05070b;
|
||||
}
|
||||
|
||||
#video {
|
||||
display: block;
|
||||
width: 100%;
|
||||
height: auto;
|
||||
max-height: calc(100vh - 190px);
|
||||
object-fit: contain;
|
||||
}
|
||||
|
||||
.side-card {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 22px;
|
||||
padding: 18px;
|
||||
}
|
||||
|
||||
.status-grid {
|
||||
display: grid;
|
||||
grid-template-columns: 72px minmax(0, 1fr);
|
||||
gap: 10px 12px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.status-grid dt {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.status-grid dd {
|
||||
margin: 0;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.device-select {
|
||||
width: 100%;
|
||||
min-height: 34px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 8px;
|
||||
color: var(--text);
|
||||
background: var(--panel-2);
|
||||
}
|
||||
|
||||
.detections {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.detections.empty {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.det-item {
|
||||
padding: 12px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 12px;
|
||||
background: var(--panel-2);
|
||||
}
|
||||
|
||||
.det-title {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 8px;
|
||||
color: var(--green);
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.det-box {
|
||||
color: var(--muted);
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, monospace;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
@media (max-width: 980px) {
|
||||
.layout {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.topbar {
|
||||
align-items: flex-start;
|
||||
flex-direction: column;
|
||||
}
|
||||
}
|
||||
221
app/stream_worker.py
Normal file
221
app/stream_worker.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
|
||||
import cv2
|
||||
|
||||
|
||||
class StreamWorker:
|
||||
def __init__(
|
||||
self,
|
||||
stream_url: str | Callable[[], str],
|
||||
detector: Any,
|
||||
frame_skip: int = 3,
|
||||
jpeg_quality: int = 80,
|
||||
resize_width: int | None = None,
|
||||
):
|
||||
self.stream_url = stream_url
|
||||
self.detector = detector
|
||||
self.frame_skip = max(1, frame_skip)
|
||||
self.jpeg_quality = jpeg_quality
|
||||
self.resize_width = resize_width
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.latest_jpeg: bytes | None = None
|
||||
self.latest_detections: list[dict[str, Any]] = []
|
||||
self.frame_id = 0
|
||||
self.updated_at = 0.0
|
||||
self.running = False
|
||||
self.connected = False
|
||||
self.error = "尚未启动"
|
||||
self.fps = 0.0
|
||||
self.thread: threading.Thread | None = None
|
||||
self.reconnect_requested = False
|
||||
self.reconnect_version = 0
|
||||
self.resolve_ms = 0.0
|
||||
self.open_ms = 0.0
|
||||
self.first_frame_ms = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.running = False
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=2)
|
||||
|
||||
def reconnect(self) -> None:
|
||||
with self.lock:
|
||||
self.latest_jpeg = None
|
||||
self.latest_detections = []
|
||||
self.frame_id = 0
|
||||
self.fps = 0.0
|
||||
self.reconnect_requested = True
|
||||
self.reconnect_version += 1
|
||||
self.connected = False
|
||||
self.error = "正在切换视频源"
|
||||
self.resolve_ms = 0.0
|
||||
self.open_ms = 0.0
|
||||
self.first_frame_ms = 0.0
|
||||
|
||||
def get_jpeg(self) -> bytes | None:
|
||||
with self.lock:
|
||||
return self.latest_jpeg
|
||||
|
||||
def get_snapshot(self) -> dict[str, Any]:
|
||||
with self.lock:
|
||||
return {
|
||||
"frame_id": self.frame_id,
|
||||
"updated_at": self.updated_at,
|
||||
"detections": list(self.latest_detections),
|
||||
"running": self.running,
|
||||
"connected": self.connected,
|
||||
"error": self.error,
|
||||
"fps": round(self.fps, 2),
|
||||
"timings": {
|
||||
"resolve_ms": self.resolve_ms,
|
||||
"open_ms": self.open_ms,
|
||||
"first_frame_ms": self.first_frame_ms,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> None:
|
||||
cap: cv2.VideoCapture | None = None
|
||||
last_detections: list[dict[str, Any]] = []
|
||||
fps_window_start = time.monotonic()
|
||||
fps_frames = 0
|
||||
|
||||
while self.running:
|
||||
with self.lock:
|
||||
should_reconnect = self.reconnect_requested
|
||||
run_version = self.reconnect_version
|
||||
self.reconnect_requested = False
|
||||
if should_reconnect:
|
||||
# 切换摄像头时必须释放旧连接,否则 OpenCV 会继续阻塞读旧流。
|
||||
if cap is not None:
|
||||
cap.release()
|
||||
cap = None
|
||||
|
||||
if cap is None or not cap.isOpened():
|
||||
started = time.monotonic()
|
||||
stream_url = self.stream_url() if callable(self.stream_url) else self.stream_url
|
||||
resolve_ms = round((time.monotonic() - started) * 1000, 2)
|
||||
started = time.monotonic()
|
||||
cap = cv2.VideoCapture(stream_url)
|
||||
open_ms = round((time.monotonic() - started) * 1000, 2)
|
||||
with self.lock:
|
||||
self.open_ms = open_ms
|
||||
self.resolve_ms = resolve_ms
|
||||
self.first_frame_ms = 0.0
|
||||
if not cap.isOpened():
|
||||
self._set_connection_state(False, "无法打开视频流,2 秒后重试")
|
||||
cap.release()
|
||||
cap = None
|
||||
time.sleep(2)
|
||||
continue
|
||||
self._set_connection_state(True, "已连接")
|
||||
|
||||
started = time.monotonic()
|
||||
ok, frame = cap.read()
|
||||
with self.lock:
|
||||
current_version = self.reconnect_version
|
||||
# 丢弃切换期间从旧连接读到的帧,避免前端画面回跳。
|
||||
if current_version != run_version:
|
||||
continue
|
||||
if self.first_frame_ms == 0.0:
|
||||
with self.lock:
|
||||
self.first_frame_ms = round((time.monotonic() - started) * 1000, 2)
|
||||
if not ok:
|
||||
self._set_connection_state(False, "读取视频帧失败,正在重连")
|
||||
cap.release()
|
||||
cap = None
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
frame = self._resize(frame)
|
||||
self.frame_id += 1
|
||||
|
||||
if self.frame_id % self.frame_skip == 0:
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
last_detections = self.detector.detect(frame_rgb)
|
||||
|
||||
annotated = self._draw(frame, last_detections)
|
||||
ok, jpeg = cv2.imencode(
|
||||
".jpg",
|
||||
annotated,
|
||||
[cv2.IMWRITE_JPEG_QUALITY, self.jpeg_quality],
|
||||
)
|
||||
if not ok:
|
||||
continue
|
||||
|
||||
fps_frames += 1
|
||||
now = time.monotonic()
|
||||
if now - fps_window_start >= 1:
|
||||
fps = fps_frames / (now - fps_window_start)
|
||||
fps_window_start = now
|
||||
fps_frames = 0
|
||||
else:
|
||||
fps = self.fps
|
||||
|
||||
with self.lock:
|
||||
current_version = self.reconnect_version
|
||||
if current_version != run_version:
|
||||
continue
|
||||
|
||||
with self.lock:
|
||||
self.latest_jpeg = jpeg.tobytes()
|
||||
self.latest_detections = list(last_detections)
|
||||
self.updated_at = time.time()
|
||||
self.connected = True
|
||||
self.error = ""
|
||||
self.fps = fps
|
||||
|
||||
if cap is not None:
|
||||
cap.release()
|
||||
self._set_connection_state(False, "已停止")
|
||||
|
||||
def _resize(self, frame: Any) -> Any:
|
||||
if not self.resize_width:
|
||||
return frame
|
||||
|
||||
height, width = frame.shape[:2]
|
||||
if width <= self.resize_width:
|
||||
return frame
|
||||
|
||||
scale = self.resize_width / width
|
||||
return cv2.resize(frame, (self.resize_width, int(height * scale)))
|
||||
|
||||
def _draw(self, frame: Any, detections: list[dict[str, Any]]) -> Any:
|
||||
for detection in detections:
|
||||
x1, y1, x2, y2 = detection["box"]
|
||||
label = detection["label"]
|
||||
score = detection["score"]
|
||||
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 220, 80), 2)
|
||||
text = f"{label} {score:.2f}"
|
||||
cv2.rectangle(frame, (x1, max(0, y1 - 26)), (x1 + 150, y1), (0, 220, 80), -1)
|
||||
cv2.putText(
|
||||
frame,
|
||||
text,
|
||||
(x1 + 5, max(18, y1 - 7)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.55,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
return frame
|
||||
|
||||
def _set_connection_state(self, connected: bool, error: str) -> None:
|
||||
with self.lock:
|
||||
self.connected = connected
|
||||
self.error = error
|
||||
self.updated_at = time.time()
|
||||
82
app/templates/index.html
Normal file
82
app/templates/index.html
Normal file
@@ -0,0 +1,82 @@
|
||||
<!doctype html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>DETR 动态打标</title>
|
||||
<link rel="stylesheet" href="/static/style.css" />
|
||||
</head>
|
||||
<body>
|
||||
<header class="topbar">
|
||||
<div>
|
||||
<h1>DETR 动态打标</h1>
|
||||
<p>使用 Python、OpenCV、PyTorch、Transformers DETR 和 FastAPI,Mac mini m2 上运行。</p>
|
||||
</div>
|
||||
<div class="badge" id="connection">连接中</div>
|
||||
</header>
|
||||
|
||||
<main class="layout">
|
||||
<section class="video-card">
|
||||
<div class="pipeline">
|
||||
<div class="stage active">源节点</div>
|
||||
<div class="arrow">→</div>
|
||||
<div class="stage active">DETR 推理</div>
|
||||
<div class="arrow">→</div>
|
||||
<div class="stage active">OSD 打标</div>
|
||||
<div class="arrow">→</div>
|
||||
<div class="stage active">FastAPI 输出</div>
|
||||
</div>
|
||||
<div class="video-wrap">
|
||||
<img id="video" src="/video" alt="动态打标视频流" />
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<aside class="side-card">
|
||||
<section>
|
||||
<h2>运行状态</h2>
|
||||
<dl class="status-grid">
|
||||
<dt>摄像头</dt>
|
||||
<dd>
|
||||
<select id="device-select" class="device-select"></select>
|
||||
</dd>
|
||||
<dt>视频源</dt>
|
||||
<dd id="source">{{ stream_url }}</dd>
|
||||
<dt>模型</dt>
|
||||
<dd>{{ model }}</dd>
|
||||
<dt>设备</dt>
|
||||
<dd>{{ device }}</dd>
|
||||
<dt>帧号</dt>
|
||||
<dd id="frame-id">-</dd>
|
||||
<dt>FPS</dt>
|
||||
<dd id="fps">-</dd>
|
||||
<dt>状态</dt>
|
||||
<dd id="error">等待视频帧</dd>
|
||||
</dl>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h2>连接耗时</h2>
|
||||
<dl class="status-grid">
|
||||
<dt>Token</dt>
|
||||
<dd id="timing-token">-</dd>
|
||||
<dt>签名</dt>
|
||||
<dd id="timing-sign">-</dd>
|
||||
<dt>取流地址</dt>
|
||||
<dd id="timing-url">-</dd>
|
||||
<dt>打开流</dt>
|
||||
<dd id="timing-open">-</dd>
|
||||
<dt>首帧</dt>
|
||||
<dd id="timing-frame">-</dd>
|
||||
</dl>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h2>检测结果</h2>
|
||||
<div id="detections" class="detections empty">暂无目标</div>
|
||||
</section>
|
||||
</aside>
|
||||
</main>
|
||||
|
||||
<script src="/static/app.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
1
devicelist.env.example
Normal file
1
devicelist.env.example
Normal file
@@ -0,0 +1 @@
|
||||
DEVICE_NUM=33082500001327632958,33102300001327287520
|
||||
6
info.md
Normal file
6
info.md
Normal file
@@ -0,0 +1,6 @@
|
||||
能力开放接口规范: 能力开放接口规范-20260602.docx
|
||||
|
||||
接口地址:https://apicapacity.51iwifi.com
|
||||
appId: a3196a2af16ddc8f93cc
|
||||
appSecret: 7f213e12590c3c0d
|
||||
accessToken:临时生成,会过期
|
||||
11
requirements.txt
Normal file
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
opencv-python
|
||||
torch
|
||||
torchvision
|
||||
transformers
|
||||
timm
|
||||
pillow
|
||||
numpy
|
||||
python-dotenv
|
||||
jinja2
|
||||
481
tokenizer.md
Normal file
481
tokenizer.md
Normal file
@@ -0,0 +1,481 @@
|
||||
# DETR 的视觉 token 化过程说明
|
||||
|
||||
本文基于当前项目代码 `app/detector.py` 中的实现说明 DETR 的“token 化”过程。
|
||||
|
||||
当前代码使用的是 Hugging Face Transformers:
|
||||
|
||||
```python
|
||||
self.processor = DetrImageProcessor.from_pretrained(model_name)
|
||||
self.model = DetrForObjectDetection.from_pretrained(model_name)
|
||||
```
|
||||
|
||||
默认模型为:
|
||||
|
||||
```text
|
||||
facebook/detr-resnet-50
|
||||
```
|
||||
|
||||
需要注意:DETR 这里没有文本 tokenizer。它处理的是图像,因此所谓“token 化”指的是把图像经过 CNN backbone 后得到的二维视觉特征图,展开成 Transformer 可以处理的一维视觉 token 序列。
|
||||
|
||||
## 当前代码中的入口
|
||||
|
||||
在 `app/detector.py` 中,检测入口是:
|
||||
|
||||
```python
|
||||
image = Image.fromarray(frame_rgb)
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||
|
||||
outputs = self.model(**inputs)
|
||||
```
|
||||
|
||||
这里分成两部分:
|
||||
|
||||
1. `DetrImageProcessor`:做图像预处理。
|
||||
2. `DetrForObjectDetection`:在模型内部完成视觉特征提取、flatten、位置编码、Transformer 编码解码和目标检测。
|
||||
|
||||
## 总体流程
|
||||
|
||||
完整流程可以理解为:
|
||||
|
||||
```text
|
||||
OpenCV RGB 帧
|
||||
↓
|
||||
PIL Image
|
||||
↓
|
||||
DetrImageProcessor 图像预处理
|
||||
↓
|
||||
pixel_values: [batch, 3, H, W]
|
||||
pixel_mask: [batch, H, W]
|
||||
↓
|
||||
ResNet-50 backbone 提取视觉特征
|
||||
↓
|
||||
feature map: [batch, 2048, H', W']
|
||||
↓
|
||||
1×1 convolution 投影通道
|
||||
↓
|
||||
projected feature map: [batch, 256, H', W']
|
||||
↓
|
||||
flatten 空间维度 H' × W'
|
||||
↓
|
||||
visual tokens: [batch, H'×W', 256]
|
||||
↓
|
||||
加入二维位置编码
|
||||
↓
|
||||
Transformer Encoder
|
||||
↓
|
||||
Object Queries + Transformer Decoder
|
||||
↓
|
||||
类别 logits + 边界框 boxes
|
||||
↓
|
||||
post_process_object_detection 还原到原图坐标
|
||||
```
|
||||
|
||||
## 第 1 步:图像预处理
|
||||
|
||||
代码:
|
||||
|
||||
```python
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
```
|
||||
|
||||
`DetrImageProcessor` 主要做这些事情:
|
||||
|
||||
- 调整图像尺寸。
|
||||
- 转换为 PyTorch tensor。
|
||||
- 归一化像素值。
|
||||
- 生成 `pixel_values`。
|
||||
- 必要时生成 `pixel_mask`,用于标记 padding 区域。
|
||||
|
||||
输出通常包含:
|
||||
|
||||
```python
|
||||
{
|
||||
"pixel_values": Tensor[batch, 3, H, W],
|
||||
"pixel_mask": Tensor[batch, H, W]
|
||||
}
|
||||
```
|
||||
|
||||
其中:
|
||||
|
||||
- `pixel_values` 是送入模型的图像张量。
|
||||
- `pixel_mask` 用于告诉模型哪些区域是真实图像,哪些区域是 padding。
|
||||
|
||||
这一步不是文本 token 化,不会产生 `input_ids`、`attention_mask` 这类 NLP tokenizer 输出。
|
||||
|
||||
## 第 2 步:ResNet-50 提取特征图
|
||||
|
||||
模型内部首先使用 ResNet-50 backbone 处理图像:
|
||||
|
||||
```text
|
||||
pixel_values: [batch, 3, H, W]
|
||||
↓ ResNet-50
|
||||
feature map: [batch, 2048, H', W']
|
||||
```
|
||||
|
||||
`H'` 和 `W'` 是下采样后的空间尺寸。ResNet 通常会把图像下采样约 32 倍。
|
||||
|
||||
例如输入图像尺寸为:
|
||||
|
||||
```text
|
||||
[batch, 3, 800, 1280]
|
||||
```
|
||||
|
||||
经过 ResNet 后,特征图可能接近:
|
||||
|
||||
```text
|
||||
[batch, 2048, 25, 40]
|
||||
```
|
||||
|
||||
这里的每个空间位置 `(y, x)` 都是一个高层视觉特征向量。
|
||||
|
||||
## 第 3 步:1×1 卷积投影通道
|
||||
|
||||
ResNet 输出通道数通常是 `2048`,而 DETR Transformer 默认隐藏维度通常是 `256`。
|
||||
|
||||
因此模型会使用一个 `1×1 convolution` 做通道投影:
|
||||
|
||||
```text
|
||||
[batch, 2048, H', W']
|
||||
↓ 1×1 conv
|
||||
[batch, 256, H', W']
|
||||
```
|
||||
|
||||
这一步不会改变空间大小,只改变每个空间位置的特征维度。
|
||||
|
||||
可以理解为:
|
||||
|
||||
```text
|
||||
每个网格点的 2048 维向量 → 256 维向量
|
||||
```
|
||||
|
||||
## 第 4 步:flatten 成视觉 token 序列
|
||||
|
||||
这是“视觉 token 化”的核心步骤。
|
||||
|
||||
投影后的特征图形状是:
|
||||
|
||||
```text
|
||||
[batch, 256, H', W']
|
||||
```
|
||||
|
||||
模型会把二维空间维度 `H' × W'` 展开成一维序列:
|
||||
|
||||
```text
|
||||
[batch, 256, H', W']
|
||||
↓ flatten H' 和 W'
|
||||
[batch, 256, H'×W']
|
||||
↓ 调整维度顺序
|
||||
[batch, H'×W', 256]
|
||||
```
|
||||
|
||||
如果特征图是:
|
||||
|
||||
```text
|
||||
[batch, 256, 25, 40]
|
||||
```
|
||||
|
||||
那么 token 数是:
|
||||
|
||||
```text
|
||||
25 × 40 = 1000
|
||||
```
|
||||
|
||||
最终得到:
|
||||
|
||||
```text
|
||||
[batch, 1000, 256]
|
||||
```
|
||||
|
||||
也就是说:
|
||||
|
||||
```text
|
||||
每个特征图网格位置 = 1 个视觉 token
|
||||
每个视觉 token = 1 个 256 维向量
|
||||
```
|
||||
|
||||
伪代码可以写成:
|
||||
|
||||
```python
|
||||
# x: [batch, 256, h, w]
|
||||
x = x.flatten(2) # [batch, 256, h*w]
|
||||
x = x.transpose(1, 2) # [batch, h*w, 256]
|
||||
```
|
||||
|
||||
原始 DETR 论文和部分 PyTorch 实现中也常见 sequence-first 格式:
|
||||
|
||||
```python
|
||||
# x: [batch, 256, h, w]
|
||||
x = x.flatten(2) # [batch, 256, h*w]
|
||||
x = x.permute(2, 0, 1) # [h*w, batch, 256]
|
||||
```
|
||||
|
||||
两种写法本质相同,只是 Transformer 接口期望的维度顺序不同。
|
||||
|
||||
## 第 5 步:加入二维位置编码
|
||||
|
||||
Transformer 本身不理解图像中的二维空间位置。
|
||||
|
||||
flatten 后,模型只看到一串 token:
|
||||
|
||||
```text
|
||||
token_0, token_1, token_2, ..., token_N
|
||||
```
|
||||
|
||||
如果不加入位置编码,模型不知道某个 token 原来位于图像左上角、中心还是右下角。
|
||||
|
||||
因此 DETR 会为特征图每个 `(y, x)` 位置生成二维位置编码:
|
||||
|
||||
```text
|
||||
position encoding: [batch, 256, H', W']
|
||||
```
|
||||
|
||||
然后同样 flatten:
|
||||
|
||||
```text
|
||||
[batch, 256, H', W']
|
||||
↓
|
||||
[batch, H'×W', 256]
|
||||
```
|
||||
|
||||
Transformer Encoder 接收的是:
|
||||
|
||||
```text
|
||||
visual token + positional encoding
|
||||
```
|
||||
|
||||
位置编码让模型知道 token 的空间布局。
|
||||
|
||||
## 第 6 步:Transformer Encoder 处理视觉 token
|
||||
|
||||
经过 flatten 和位置编码后,视觉 token 序列进入 Transformer Encoder:
|
||||
|
||||
```text
|
||||
[batch, H'×W', 256]
|
||||
↓ Transformer Encoder
|
||||
[batch, H'×W', 256]
|
||||
```
|
||||
|
||||
Encoder 会通过 self-attention 建模图像中不同区域之间的关系。
|
||||
|
||||
例如:
|
||||
|
||||
- 车头区域可以关注车身区域。
|
||||
- 道路区域可以影响车辆判断。
|
||||
- 远处小目标可以和周围上下文一起被理解。
|
||||
|
||||
## 第 7 步:Object Queries 和 Transformer Decoder
|
||||
|
||||
DETR 与传统检测器不同,它不是先生成大量 anchor box。
|
||||
|
||||
它使用一组可学习的 object queries。常见数量是:
|
||||
|
||||
```text
|
||||
100 个 object queries
|
||||
```
|
||||
|
||||
这些 queries 进入 Transformer Decoder,并关注 Encoder 输出的视觉 token:
|
||||
|
||||
```text
|
||||
object queries: [batch, 100, 256]
|
||||
encoder tokens: [batch, H'×W', 256]
|
||||
↓ Transformer Decoder
|
||||
object features: [batch, 100, 256]
|
||||
```
|
||||
|
||||
每个 query 最终预测一个候选目标:
|
||||
|
||||
```text
|
||||
类别 + 边界框
|
||||
```
|
||||
|
||||
因此 DETR 输出通常可以理解为:
|
||||
|
||||
```text
|
||||
最多 100 个候选目标
|
||||
```
|
||||
|
||||
每个候选目标会包含:
|
||||
|
||||
- 类别 logits。
|
||||
- 归一化边界框。
|
||||
|
||||
## 第 8 步:后处理成车辆检测结果
|
||||
|
||||
当前代码中的后处理是:
|
||||
|
||||
```python
|
||||
target_sizes = torch.tensor([image.size[::-1]], device=self.device)
|
||||
results = self.processor.post_process_object_detection(
|
||||
outputs,
|
||||
target_sizes=target_sizes,
|
||||
threshold=self.confidence,
|
||||
)[0]
|
||||
```
|
||||
|
||||
这一步会:
|
||||
|
||||
- 把模型输出的归一化框还原为原图坐标。
|
||||
- 根据置信度阈值过滤低分检测。
|
||||
- 返回 `scores`、`labels`、`boxes`。
|
||||
|
||||
然后代码过滤车辆类别:
|
||||
|
||||
```python
|
||||
label_name = self.model.config.id2label[label.item()]
|
||||
if label_name not in self.vehicle_labels:
|
||||
continue
|
||||
```
|
||||
|
||||
默认车辆类别为:
|
||||
|
||||
```text
|
||||
car, motorcycle, bus, truck, bicycle
|
||||
```
|
||||
|
||||
最终输出格式:
|
||||
|
||||
```python
|
||||
{
|
||||
"label": "car",
|
||||
"score": 0.9132,
|
||||
"box": [x1, y1, x2, y2]
|
||||
}
|
||||
```
|
||||
|
||||
## 与文本 tokenizer 的区别
|
||||
|
||||
| 对比项 | 文本 tokenizer | DETR 视觉 token 化 |
|
||||
| --- | --- | --- |
|
||||
| 输入 | 文本字符串 | 图像张量 |
|
||||
| 输出 | token id 序列 | 视觉特征向量序列 |
|
||||
| token 来源 | 词、子词、字符片段 | CNN 特征图空间网格 |
|
||||
| token 内容 | 离散整数 id | 连续浮点向量 |
|
||||
| 位置编码 | 一维位置编码 | 二维图像位置编码 |
|
||||
| 当前代码中的类 | 无 | `DetrImageProcessor` + `DetrForObjectDetection` |
|
||||
|
||||
## 为什么说它是 patch-like features
|
||||
|
||||
它们可以叫做 `patch-like features`,但不是 ViT 那种直接切原图 patch。
|
||||
|
||||
ViT 通常是:
|
||||
|
||||
```text
|
||||
原图 -> 切成 16×16 patch -> 线性投影 -> token
|
||||
```
|
||||
|
||||
DETR-ResNet 是:
|
||||
|
||||
```text
|
||||
原图 -> ResNet 特征图 -> 每个特征图网格点 -> token
|
||||
```
|
||||
|
||||
因此 DETR 的每个 token:
|
||||
|
||||
- 对应特征图上的一个空间位置。
|
||||
- 感受野来自 ResNet 深层网络。
|
||||
- 通常覆盖原图中一片较大的区域。
|
||||
- 相邻 token 的感受野会重叠。
|
||||
|
||||
## 一个具体尺寸例子
|
||||
|
||||
假设预处理后图像大小为:
|
||||
|
||||
```text
|
||||
800 × 1280
|
||||
```
|
||||
|
||||
ResNet-50 下采样约 32 倍:
|
||||
|
||||
```text
|
||||
H' = 800 / 32 ≈ 25
|
||||
W' = 1280 / 32 ≈ 40
|
||||
```
|
||||
|
||||
Backbone 输出:
|
||||
|
||||
```text
|
||||
[batch, 2048, 25, 40]
|
||||
```
|
||||
|
||||
1×1 卷积投影:
|
||||
|
||||
```text
|
||||
[batch, 256, 25, 40]
|
||||
```
|
||||
|
||||
flatten:
|
||||
|
||||
```text
|
||||
[batch, 1000, 256]
|
||||
```
|
||||
|
||||
所以该图像大约会生成:
|
||||
|
||||
```text
|
||||
1000 个视觉 token
|
||||
```
|
||||
|
||||
每个 token 是:
|
||||
|
||||
```text
|
||||
256 维连续向量
|
||||
```
|
||||
|
||||
## 当前项目中的实际检测路径
|
||||
|
||||
当前项目的整体路径是:
|
||||
|
||||
```text
|
||||
RTSP/HLS 视频帧
|
||||
↓
|
||||
OpenCV 读取 BGR frame
|
||||
↓
|
||||
cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
↓
|
||||
Image.fromarray(frame_rgb)
|
||||
↓
|
||||
DetrImageProcessor 图像预处理
|
||||
↓
|
||||
DetrForObjectDetection 内部完成视觉 token 化和目标检测
|
||||
↓
|
||||
post_process_object_detection
|
||||
↓
|
||||
过滤车辆类别
|
||||
↓
|
||||
OpenCV OSD 画框
|
||||
↓
|
||||
FastAPI /video 输出动态打标画面
|
||||
```
|
||||
|
||||
对应代码位置:
|
||||
|
||||
- `app/detector.py`:DETR 图像预处理、模型推理、后处理。
|
||||
- `app/stream_worker.py`:视频帧读取、推理调用、OSD 画框。
|
||||
- `app/main.py`:MJPEG 视频流和检测结果接口。
|
||||
|
||||
## 小结
|
||||
|
||||
当前项目中的 DETR 不包含文本 tokenizer。
|
||||
|
||||
它的视觉 token 化具体是:
|
||||
|
||||
```text
|
||||
图像经过 ResNet-50 得到二维特征图
|
||||
↓
|
||||
1×1 卷积把通道投影到 Transformer hidden size
|
||||
↓
|
||||
把 H' × W' 个空间位置 flatten 成 H'×W' 个视觉 token
|
||||
↓
|
||||
给每个 token 加二维位置编码
|
||||
↓
|
||||
送入 Transformer Encoder
|
||||
```
|
||||
|
||||
因此,“token 数”主要由输入分辨率和 backbone 下采样比例决定:
|
||||
|
||||
```text
|
||||
视觉 token 数 ≈ (输入高度 / 32) × (输入宽度 / 32)
|
||||
```
|
||||
|
||||
实际数量会受到图像预处理 resize、padding 和特征图尺寸取整影响。
|
||||
BIN
能力开放接口规范-20260602.docx
Normal file
BIN
能力开放接口规范-20260602.docx
Normal file
Binary file not shown.
Reference in New Issue
Block a user