Files
tokenresearch/app/main.py
2026-06-03 11:04:16 +08:00

166 lines
5.2 KiB
Python

"""FastAPI 入口模块:提供页面、视频流、状态、设备切换和 WebSocket 接口。"""
from __future__ import annotations

import asyncio
from pathlib import Path

from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from app.capacity_api import CapacityApiClient
from app.config import load_settings, mask_url
from app.detector import DetrVehicleDetector
from app.device_manager import DeviceManager
from app.stream_worker import StreamWorker
# Shared service objects, created once at import time from the loaded settings.
settings = load_settings()

# Capacity-API client; every constructor argument comes straight from settings.
api_client = CapacityApiClient(
    base_url=settings.api_base_url,
    app_id=settings.app_id,
    app_secret=settings.app_secret,
    account=settings.device_account,
    method=settings.stream_method,
)

# DETR detector configured with the model name, confidence threshold and the
# set of labels treated as vehicles.
detector = DetrVehicleDetector(
    model_name=settings.detr_model,
    confidence=settings.confidence,
    vehicle_labels=settings.vehicle_labels,
)

# Device registry built from the device-list path. It is handed the API client
# above and settings.stream_url as ``fallback_url`` (presumably used when no
# device URL can be resolved — confirm in DeviceManager).
device_manager = DeviceManager(
    path=settings.device_list_path,
    api_client=api_client,
    fallback_url=settings.stream_url,
)
def resolve_stream_url() -> str:
    """Return the playback URL for the currently selected device.

    Delegates to ``device_manager.resolve_stream_url()``. The StreamWorker
    receives this function itself (not its result), presumably so the URL
    can be looked up again on reconnect — confirm in StreamWorker.
    """
    url = device_manager.resolve_stream_url()
    return url
# Frame-processing worker. ``stream_url`` is given as a callable rather than a
# fixed string, presumably so the current device's URL is re-resolved whenever
# the worker (re)connects — confirm in StreamWorker.
worker = StreamWorker(
    stream_url=resolve_stream_url,
    detector=detector,
    frame_skip=settings.frame_skip,
    jpeg_quality=settings.jpeg_quality,
    resize_width=settings.resize_width,
)
# Directory that contains this module (the ``app`` package). Static assets and
# templates are resolved against it instead of the process working directory,
# so startup does not depend on where the server is launched from.
_APP_DIR = Path(__file__).resolve().parent

app = FastAPI(title="DETR 动态打标")
app.mount("/static", StaticFiles(directory=str(_APP_DIR / "static")), name="static")
templates = Jinja2Templates(directory=str(_APP_DIR / "templates"))
def display_model_name(model_name: str) -> str:
    """Return the last ``/``-separated segment of a model id or path for display.

    Examples: ``"facebook/detr-resnet-50"`` -> ``"detr-resnet-50"``;
    ``"./models/detr-resnet-101/"`` -> ``"detr-resnet-101"``.

    Trailing slashes are ignored so that a directory-style path does not
    produce an empty name. Input without a ``/`` is returned unchanged.
    """
    return model_name.rstrip("/").rsplit("/", 1)[-1]
@app.on_event("startup")
def startup() -> None:
    """Start the stream worker when the application starts up."""
    worker.start()
@app.on_event("shutdown")
def shutdown() -> None:
    """Stop the stream worker when the application shuts down."""
    worker.stop()
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render ``index.html`` with the model name, compute device and current device number.

    Note: the template's ``stream_url`` slot is filled with the current device
    number, not the real playback URL (presumably so the URL is never rendered
    into the page).
    """
    context = {
        "request": request,
        "model": display_model_name(settings.detr_model),
        "device": detector.device_name,
    }
    current_num = device_manager.get_snapshot()["current_device_num"]
    context["stream_url"] = f"设备号:{current_num}"
    return templates.TemplateResponse("index.html", context)
@app.get("/video")
def video() -> StreamingResponse:
    """Stream annotated frames as MJPEG (``multipart/x-mixed-replace``, boundary ``frame``).

    Whatever ``worker.get_jpeg()`` currently returns is sent as one part,
    whether or not it changed since the previous part. The loop never ends on
    its own; it stops when the response is torn down.
    """
    part_header = b"--frame\r\nContent-Type: image/jpeg\r\n\r\n"

    async def mjpeg_parts():
        while True:
            jpeg = worker.get_jpeg()
            if jpeg is not None:
                yield part_header + jpeg + b"\r\n"
                # Pause between parts, capping output at roughly 33 parts/s.
                await asyncio.sleep(0.03)
            else:
                # No encoded frame available yet; poll again shortly.
                await asyncio.sleep(0.1)

    return StreamingResponse(
        mjpeg_parts(),
        media_type="multipart/x-mixed-replace; boundary=frame",
    )
@app.get("/detections")
def detections() -> JSONResponse:
    """Return the latest ``frame_id``, ``updated_at`` and ``detections`` from the worker as JSON."""
    snapshot = worker.get_snapshot()
    payload = {key: snapshot[key] for key in ("frame_id", "updated_at", "detections")}
    return JSONResponse(payload)
@app.get("/status")
def status() -> JSONResponse:
    """Return worker health, model/config details, the device list and timing stats.

    ``source`` is the masked current playback URL, or a "waiting" placeholder
    when no URL is known yet. ``timings`` merges the device manager's source
    timings with the worker's timings; on key collisions the worker's value wins.
    """
    worker_state = worker.get_snapshot()
    device_state = device_manager.get_snapshot()

    # Merge the stream-URL fetch timings with the OpenCV stream-read timings so
    # the frontend can display both from a single ``timings`` object.
    merged_timings = dict(device_state["source_timings"])
    merged_timings.update(worker_state["timings"])

    current_url = device_state["current_url"]
    if current_url:
        source = mask_url(current_url)
    else:
        source = "等待获取播放地址"

    return JSONResponse(
        {
            "running": worker_state["running"],
            "connected": worker_state["connected"],
            "frame_id": worker_state["frame_id"],
            "updated_at": worker_state["updated_at"],
            "fps": worker_state["fps"],
            "error": worker_state["error"],
            "source": source,
            "model": display_model_name(settings.detr_model),
            "device": detector.device_name,
            "frame_skip": settings.frame_skip,
            "confidence": settings.confidence,
            "devices": device_state["devices"],
            "current_device_num": device_state["current_device_num"],
            "timings": merged_timings,
        }
    )
@app.post("/devices/{device_num}")
def switch_device(device_num: str) -> JSONResponse:
    """Make *device_num* the current device and ask the worker to reconnect.

    Raises:
        HTTPException: 404 when ``device_manager.set_current_device`` raises
            ``ValueError``; the error text becomes the response detail.
    """
    try:
        device_manager.set_current_device(device_num)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    else:
        worker.reconnect()
    return JSONResponse({"current_device_num": device_num})
@app.websocket("/ws/detections")
async def websocket_detections(websocket: WebSocket) -> None:
    """Push the worker snapshot plus device info to the client every 0.5 s.

    Each message is the worker snapshot extended with ``devices``,
    ``current_device_num``, ``source`` (masked URL or a placeholder) and a
    merged ``timings`` object (device source timings overlaid by worker
    timings). The loop ends quietly when ``WebSocketDisconnect`` is raised.
    """
    await websocket.accept()
    try:
        while True:
            payload = worker.get_snapshot()
            devices_state = device_manager.get_snapshot()

            current_url = devices_state["current_url"]
            source = mask_url(current_url) if current_url else "等待获取播放地址"
            merged_timings = {**devices_state["source_timings"], **payload["timings"]}

            # NOTE(review): this writes into the object returned by
            # worker.get_snapshot(); that is only safe if the worker returns a
            # fresh copy per call — confirm in StreamWorker.
            payload["devices"] = devices_state["devices"]
            payload["current_device_num"] = devices_state["current_device_num"]
            payload["source"] = source
            payload["timings"] = merged_timings

            await websocket.send_json(payload)
            await asyncio.sleep(0.5)
    except WebSocketDisconnect:
        return