166 lines
5.2 KiB
Python
166 lines
5.2 KiB
Python
"""FastAPI 入口模块:提供页面、视频流、状态、设备切换和 WebSocket 接口。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
|
|
from app.capacity_api import CapacityApiClient
|
|
from app.config import load_settings, mask_url
|
|
from app.detector import DetrVehicleDetector
|
|
from app.device_manager import DeviceManager
|
|
from app.stream_worker import StreamWorker
|
|
|
|
# Application-wide settings, loaded once at import time; every component
# below is wired from this object.
settings = load_settings()

# Client for the capacity API. It is handed to DeviceManager below, which
# presumably uses it to fetch per-device playback URLs — confirm in
# app.capacity_api / app.device_manager.
api_client = CapacityApiClient(
    base_url=settings.api_base_url,
    app_id=settings.app_id,
    app_secret=settings.app_secret,
    account=settings.device_account,
    method=settings.stream_method,
)

# DETR vehicle detector configured from settings (model name, confidence
# threshold, label set). Built at import time, so any model-loading cost
# (NOTE(review): assumed, not visible here) is paid when this module is imported.
detector = DetrVehicleDetector(
    model_name=settings.detr_model,
    confidence=settings.confidence,
    vehicle_labels=settings.vehicle_labels,
)

# Holds the device list (read from settings.device_list_path) and the
# currently selected device. fallback_url is presumably used when no device
# URL can be resolved — TODO confirm in DeviceManager.
device_manager = DeviceManager(
    path=settings.device_list_path,
    api_client=api_client,
    fallback_url=settings.stream_url,
)
|
|
def resolve_stream_url() -> str:
    """Return the playback URL for the currently selected device.

    StreamWorker receives this function itself (not a URL string).
    Presumably this lets the worker ask DeviceManager for a fresh URL whenever
    it (re)connects, e.g. after ``switch_device`` calls ``worker.reconnect()``.
    """
    manager = device_manager
    return manager.resolve_stream_url()
|
|
|
|
|
|
# Background stream processor. It is started and stopped by the startup and
# shutdown hooks in this module. stream_url is the resolve_stream_url callable
# rather than a fixed string, so the URL can change when the current device is
# switched. The remaining knobs come straight from settings.
worker = StreamWorker(
    stream_url=resolve_stream_url,
    detector=detector,
    frame_skip=settings.frame_skip,
    jpeg_quality=settings.jpeg_quality,
    resize_width=settings.resize_width,
)
|
|
|
|
app = FastAPI(title="DETR 动态打标")

# Static assets and Jinja2 templates. The directory paths are relative, so
# they resolve against the process's working directory. The app is presumably
# launched from the project root. StaticFiles checks that the directory exists
# when it is constructed.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")
|
|
|
|
|
|
def display_model_name(model_name: str) -> str:
    """Strip any namespace prefix from a model identifier for display.

    Only the text after the final "/" is kept, e.g. "org/model" -> "model".
    Identifiers without a "/" are returned unchanged.
    """
    _prefix, _sep, short_name = model_name.rpartition("/")
    return short_name
|
|
|
|
|
|
@app.on_event("startup")
def startup() -> None:
    """Start the background stream worker once the application boots."""
    worker.start()
|
|
|
|
|
|
@app.on_event("shutdown")
def shutdown() -> None:
    """Stop the background stream worker when the application shuts down."""
    worker.stop()
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render the main page.

    The template context carries the short model name, the detector's
    inference device, and the current device number.
    """
    page_context = {"request": request}
    page_context["model"] = display_model_name(settings.detr_model)
    page_context["device"] = detector.device_name
    # The "stream_url" slot shows the device number, not the raw URL
    # (presumably to avoid exposing the playback address on the page).
    current_device_num = device_manager.get_snapshot()["current_device_num"]
    page_context["stream_url"] = f"设备号:{current_device_num}"
    return templates.TemplateResponse("index.html", page_context)
|
|
|
|
|
|
@app.get("/video")
def video() -> StreamingResponse:
    """Stream the worker's JPEG frames as MJPEG (multipart/x-mixed-replace)."""

    async def generate():
        # Each multipart part: boundary line, JPEG content type, blank line, payload.
        part_header = b"--frame\r\nContent-Type: image/jpeg\r\n\r\n"
        while True:
            jpeg = worker.get_jpeg()
            if jpeg is not None:
                yield part_header + jpeg + b"\r\n"
                # Pace pushes to at most ~33 per second. Whatever get_jpeg()
                # returns is sent again even if it has not changed.
                await asyncio.sleep(0.03)
            else:
                # No frame available yet; poll again shortly.
                await asyncio.sleep(0.1)

    return StreamingResponse(
        generate(),
        media_type="multipart/x-mixed-replace; boundary=frame",
    )
|
|
|
|
|
|
@app.get("/detections")
def detections() -> JSONResponse:
    """Return the latest detections together with their frame id and timestamp."""
    snapshot = worker.get_snapshot()
    payload = {
        field: snapshot[field]
        for field in ("frame_id", "updated_at", "detections")
    }
    return JSONResponse(payload)
|
|
|
|
|
|
@app.get("/status")
def status() -> JSONResponse:
    """Report combined stream, device, and model state for the front end.

    The response merges three sources:
    - the worker snapshot: running/connected flags, frame counters, fps, last error;
    - the device manager snapshot: device list, current device, masked source URL;
    - static detector settings.
    """
    worker_state = worker.get_snapshot()
    devices_state = device_manager.get_snapshot()
    # Merge the stream-URL fetch timings with the OpenCV read timings, so the
    # front end renders one timings object. Worker keys win on clashes.
    merged_timings = {**devices_state["source_timings"], **worker_state["timings"]}
    current_url = devices_state["current_url"]
    return JSONResponse(
        {
            "running": worker_state["running"],
            "connected": worker_state["connected"],
            "frame_id": worker_state["frame_id"],
            "updated_at": worker_state["updated_at"],
            "fps": worker_state["fps"],
            "error": worker_state["error"],
            "source": mask_url(current_url) if current_url else "等待获取播放地址",
            "model": display_model_name(settings.detr_model),
            "device": detector.device_name,
            "frame_skip": settings.frame_skip,
            "confidence": settings.confidence,
            "devices": devices_state["devices"],
            "current_device_num": devices_state["current_device_num"],
            "timings": merged_timings,
        }
    )
|
|
|
|
|
|
@app.post("/devices/{device_num}")
def switch_device(device_num: str) -> JSONResponse:
    """Select *device_num* as the current device and reconnect the stream.

    Raises:
        HTTPException: 404 when DeviceManager rejects the device number
            (it signals this with ValueError).
    """
    try:
        device_manager.set_current_device(device_num)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    else:
        # Only reconnect once the switch succeeded. The worker's URL callable
        # then resolves the newly selected device.
        worker.reconnect()
    return JSONResponse({"current_device_num": device_num})
|
|
|
|
|
|
@app.websocket("/ws/detections")
async def websocket_detections(websocket: WebSocket) -> None:
    """Push the worker snapshot plus device state to the client every ~0.5 s.

    The payload is the worker snapshot extended with the device list, the
    current device number, the masked source URL, and merged timings. The
    loop ends quietly when the client disconnects (WebSocketDisconnect).
    """
    await websocket.accept()
    try:
        while True:
            snapshot = worker.get_snapshot()
            device_snapshot = device_manager.get_snapshot()
            current_url = device_snapshot["current_url"]
            # Build a new dict rather than calling snapshot.update(...).
            # get_snapshot() may hand back the worker's own state object (not
            # visible here), and mutating it from the event loop would leak
            # device fields into it. Key order is the same as with update().
            payload = {
                **snapshot,
                "devices": device_snapshot["devices"],
                "current_device_num": device_snapshot["current_device_num"],
                "source": mask_url(current_url) if current_url else "等待获取播放地址",
                # Stream-URL fetch timings merged with OpenCV read timings;
                # worker keys win on clashes, as in /status.
                "timings": {**device_snapshot["source_timings"], **snapshot["timings"]},
            }
            await websocket.send_json(payload)
            await asyncio.sleep(0.5)
    except WebSocketDisconnect:
        return
|