Files
digit-cracker/scripts/benchmark_concurrent.py
2025-10-30 16:31:47 +08:00

447 lines
14 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
并发性能测试脚本 - 模拟多用户同时识别
功能:
- 模拟 n 个用户并发执行数字识别任务
- 循环使用 valid 文件夹中的图片
- 统计总体性能指标(吞吐量、响应时间、成功率)
- 生成详细的性能报告
使用方法:
# 模拟10个并发用户，每个用户识别20张图片
python scripts/benchmark_concurrent.py --users 10 --images-per-user 20
# 使用不同的模型和配置
python scripts/benchmark_concurrent.py --users 5 --images-per-user 10 --conf 0.15
# 指定输出报告文件
python scripts/benchmark_concurrent.py --users 20 --output results/benchmark_report.txt
输出:
- 实时进度显示
- 详细的性能统计(总时间、平均响应时间、QPS等)
- 每个用户的执行情况
- 失败任务的详细信息
作者: Gavin Chan
日期: 2025-01-30
"""
import argparse
import time
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from ultralytics import YOLO
import cv2
import numpy as np
from datetime import datetime
import json
def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            means ``sys.argv[1:]``, exactly as ``argparse`` does.

    Returns:
        argparse.Namespace with ``users``, ``images_per_user``, ``model``,
        ``source``, ``conf``, ``imgsz``, ``output`` and ``verbose``.

    Exits via ``parser.error`` (``SystemExit`` with status 2) when
    ``--users`` or ``--images-per-user`` is not a positive integer: zero
    users would make ``ThreadPoolExecutor(max_workers=0)`` raise, and zero
    images per user leaves no timings for the report's percentiles.
    """
    parser = argparse.ArgumentParser(description='并发性能测试 - 模拟多用户识别')
    parser.add_argument('--users', type=int, default=10,
                        help='并发用户数量(默认: 10)')
    parser.add_argument('--images-per-user', type=int, default=20,
                        help='每个用户识别的图片数量(默认: 20)')
    parser.add_argument('--model', type=str,
                        default='runs/digit_yolo/exp_preprocessed_color_150/weights/best.pt',
                        help='模型路径')
    parser.add_argument('--source', type=str, default='valid',
                        help='图片源文件夹(默认: valid)')
    parser.add_argument('--conf', type=float, default=0.2,
                        help='置信度阈值(默认: 0.2)')
    parser.add_argument('--imgsz', type=int, default=320,
                        help='输入图片大小(默认: 320)')
    parser.add_argument('--output', type=str, default='results/benchmark_report.txt',
                        help='输出报告文件路径')
    parser.add_argument('--verbose', action='store_true',
                        help='显示详细日志')
    args = parser.parse_args(argv)

    # Reject values that would crash the benchmark later on.
    if args.users < 1:
        parser.error('--users 必须为正整数')
    if args.images_per_user < 1:
        parser.error('--images-per-user 必须为正整数')
    return args
def load_image_paths(source_dir):
    """Collect the image files located directly inside *source_dir*.

    Only regular files whose suffix (case-insensitive) is one of
    ``.jpg``, ``.jpeg``, ``.png``, ``.bmp`` or ``.webp`` are returned.
    Sub-directories are not searched, and a directory whose name happens
    to end in an image suffix is skipped.

    Args:
        source_dir: folder containing the images (str or Path).

    Returns:
        list[str]: the image paths as strings, sorted lexicographically.

    Raises:
        FileNotFoundError: if *source_dir* does not exist.
        NotADirectoryError: if *source_dir* exists but is not a directory.
        ValueError: if the folder contains no matching image file.
    """
    source_path = Path(source_dir)
    if not source_path.exists():
        raise FileNotFoundError(f"图片文件夹不存在: {source_dir}")
    if not source_path.is_dir():
        raise NotADirectoryError(f"图片源不是文件夹: {source_dir}")

    # Supported image formats
    extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    # Check the cheap suffix test first; is_file() needs a stat call.
    image_paths = sorted(
        str(p) for p in source_path.iterdir()
        if p.suffix.lower() in extensions and p.is_file()
    )
    if not image_paths:
        raise ValueError(f"文件夹中没有找到图片: {source_dir}")
    return image_paths
def extract_digits_from_predictions(results):
    """Convert a YOLO prediction into a left-to-right digit string.

    Only ``results[0]`` is inspected. Every detected box contributes
    ``str(int(class_id))``, so the class index itself is used as the digit.
    NOTE(review): this assumes class ids 0-9 map to digits 0-9; confirm
    against the model's ``names``. Boxes are ordered by the x-coordinate
    of their centre (stable for ties).

    Args:
        results: the sequence returned by ``YOLO.predict``.

    Returns:
        tuple: ``(digit_string, mean_confidence, box_count)``. Returns
        ``("", 0.0, 0)`` when there is no result or no box.
    """
    empty = ("", 0.0, 0)
    if not results or len(results) == 0:
        return empty
    first = results[0]
    if first.boxes is None or len(first.boxes) == 0:
        return empty

    def _as_detection(box):
        # (class id, confidence, horizontal centre of the box)
        corners = box.xyxy[0]
        return (
            int(box.cls[0]),
            float(box.conf[0]),
            float((corners[0] + corners[2]) / 2),
        )

    ordered = sorted(
        (_as_detection(box) for box in first.boxes.cpu().numpy()),
        key=lambda det: det[2],
    )
    if not ordered:
        return empty

    text = "".join(str(cls_id) for cls_id, _, _ in ordered)
    scores = [score for _, score, _ in ordered]
    return text, sum(scores) / len(scores), len(ordered)
def recognize_single_image(model, image_path, conf_threshold, img_size):
    """Run one prediction and package the outcome as a result record.

    Any exception raised by ``model.predict`` or by digit extraction is
    caught and recorded in the returned dict, so one bad image does not
    abort a whole simulated user.

    Args:
        model: loaded YOLO model (anything exposing a compatible ``predict``).
        image_path: path of the image to recognise.
        conf_threshold: confidence threshold passed to ``predict`` as ``conf``.
        img_size: inference size passed to ``predict`` as ``imgsz``.

    Returns:
        dict with keys:
        ``filename`` (base name of *image_path*), ``digits`` (str),
        ``confidence`` (mean box confidence), ``count`` (number of boxes),
        ``time`` (elapsed seconds), ``success`` (bool) and ``error``
        (``None`` on success; otherwise the exception message, or the
        exception type name when that message is empty).
    """
    filename = Path(image_path).name
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump when the system clock is adjusted.
    start = time.perf_counter()
    try:
        results = model.predict(
            source=image_path,
            conf=conf_threshold,
            imgsz=img_size,
            verbose=False,
        )
        digits, confidence, count = extract_digits_from_predictions(results)
    except Exception as exc:  # benchmark boundary: record the failure, keep going
        return {
            'filename': filename,
            'digits': '',
            'confidence': 0.0,
            'count': 0,
            'time': time.perf_counter() - start,
            'success': False,
            # Some exceptions (e.g. a bare KeyError) have an empty message.
            'error': str(exc) or type(exc).__name__,
        }
    return {
        'filename': filename,
        'digits': digits,
        'confidence': confidence,
        'count': count,
        'time': time.perf_counter() - start,
        'success': True,
        'error': None,
    }
def user_task(user_id, model_path, image_paths, num_images, conf_threshold, img_size, verbose):
    """Simulate one user recognising ``num_images`` images one after another.

    A fresh YOLO model is loaded on every call, so each worker thread
    gets its own model instance. Loading happens before the user's timer
    starts, so ``total_time`` covers recognition only.

    Images are taken from *image_paths* starting at index
    ``user_id * num_images`` and wrapping around the list, so consecutive
    users get consecutive windows of the image list.

    Args:
        user_id: numeric id of the simulated user.
        model_path: path to the YOLO weights file.
        image_paths: all available image paths.
        num_images: number of images this user recognises.
        conf_threshold: detection confidence threshold.
        img_size: inference image size.
        verbose: print one line per image when true.

    Returns:
        dict with ``user_id``, ``results`` (list of per-image records from
        ``recognize_single_image``), ``total_time`` (seconds),
        ``success_count`` and ``avg_time`` (mean per-image seconds, 0 when
        there are no images).
    """
    model = YOLO(model_path)

    # Round-robin over the image pool.
    offset = user_id * num_images
    user_images = [image_paths[(offset + i) % len(image_paths)] for i in range(num_images)]

    # Monotonic clock for duration measurement.
    start_time = time.perf_counter()
    results = []
    for img_path in user_images:
        result = recognize_single_image(model, img_path, conf_threshold, img_size)
        results.append(result)
        if verbose:
            # Distinct markers so failures stand out in the verbose log.
            status = "✓" if result['success'] else "✗"
            print(f" [{status}] 用户{user_id:2d} | {result['filename']:20s} | "
                  f"{result['digits']:6s} | {result['time']:.3f}s")
    total_time = time.perf_counter() - start_time

    success_count = sum(1 for r in results if r['success'])
    avg_time = sum(r['time'] for r in results) / len(results) if results else 0
    return {
        'user_id': user_id,
        'results': results,
        'total_time': total_time,
        'success_count': success_count,
        'avg_time': avg_time,
    }
def _percentile(sorted_values, pct):
    """Return the nearest-rank *pct*-th percentile of an ascending list (0 if empty)."""
    n = len(sorted_values)
    if n == 0:
        return 0
    # Nearest-rank: 1-based rank ceil(pct * n / 100), computed with integer
    # arithmetic so exact multiples (e.g. P90 of 10 samples) are not pushed
    # one element too high.
    rank = -(-pct * n // 100)
    return sorted_values[max(rank, 1) - 1]


def generate_report(user_results, total_time, args):
    """Build the plain-text benchmark report.

    Args:
        user_results: list of dicts as returned by ``user_task`` (keys
            ``user_id``, ``results``, ``total_time``, ``success_count``,
            ``avg_time``). Each item of ``results`` carries ``filename``,
            ``time``, ``success`` and ``error``.
        total_time: duration of the whole run in seconds. Used for QPS.
        args: parsed CLI options. Reads ``users``, ``images_per_user``,
            ``model``, ``conf``, ``imgsz`` and ``source``.

    Returns:
        str: the report, one line per list entry joined by newlines.
        Works (with zeros) when no image was processed at all.
    """
    ordered_users = sorted(user_results, key=lambda ur: ur['user_id'])

    # Aggregate counts
    total_images = sum(len(ur['results']) for ur in user_results)
    total_success = sum(ur['success_count'] for ur in user_results)
    total_failed = total_images - total_success
    success_rate = (total_success / total_images * 100) if total_images > 0 else 0

    # Per-image response-time statistics
    all_times = [r['time'] for ur in user_results for r in ur['results']]
    avg_response_time = sum(all_times) / len(all_times) if all_times else 0
    min_response_time = min(all_times) if all_times else 0
    max_response_time = max(all_times) if all_times else 0
    sorted_times = sorted(all_times)

    # Throughput
    qps = total_images / total_time if total_time > 0 else 0

    report = [
        "=" * 80,
        "YOLO 数字识别并发性能测试报告",
        "=" * 80,
        f"\n测试时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "\n配置参数:",
        f" - 并发用户数: {args.users}",
        f" - 每用户图片数: {args.images_per_user}",
        f" - 模型: {args.model}",
        f" - 置信度阈值: {args.conf}",
        f" - 图片大小: {args.imgsz}",
        f" - 图片源: {args.source}",
        "\n总体性能:",
        f" - 总执行时间: {total_time:.2f}s",
        f" - 总识别图片数: {total_images}",
        f" - 成功: {total_success} ({success_rate:.1f}%)",
        f" - 失败: {total_failed}",
        f" - 吞吐量 (QPS): {qps:.2f} 图片/秒",
        "\n响应时间统计:",
        f" - 平均响应时间: {avg_response_time:.3f}s",
        f" - 最小响应时间: {min_response_time:.3f}s",
        f" - 最大响应时间: {max_response_time:.3f}s",
    ]
    for pct in (50, 90, 95, 99):
        report.append(f" - P{pct} 响应时间: {_percentile(sorted_times, pct):.3f}s")

    # Per-user table
    report.append("\n各用户性能:")
    report.append(f" {'用户ID':>8} | {'图片数':>8} | {'成功':>8} | {'总耗时':>10} | {'平均耗时':>10}")
    report.append(f" {'-'*8}-+-{'-'*8}-+-{'-'*8}-+-{'-'*10}-+-{'-'*10}")
    for ur in ordered_users:
        report.append(f" {ur['user_id']:8d} | {len(ur['results']):8d} | "
                      f"{ur['success_count']:8d} | {ur['total_time']:9.2f}s | "
                      f"{ur['avg_time']:9.3f}s")

    # Failed-task details, listed in the same user order as the table above
    failed_tasks = [(ur['user_id'], r)
                    for ur in ordered_users
                    for r in ur['results']
                    if not r['success']]
    max_shown = 20
    if failed_tasks:
        report.append(f"\n失败任务详情 (共 {len(failed_tasks)} 个):")
        for user_id, result in failed_tasks[:max_shown]:
            report.append(f" 用户{user_id} | {result['filename']} | 错误: {result['error']}")
        if len(failed_tasks) > max_shown:
            report.append(f" ... 还有 {len(failed_tasks) - max_shown} 个失败任务")

    report.append("\n" + "=" * 80)
    return '\n'.join(report)
def main():
    """Command-line entry point: run the concurrent benchmark and save reports.

    Parses arguments, checks that the model file exists and collects the
    source images. It then runs one ``user_task`` per simulated user in a
    thread pool with one worker per user, and prints the text report.
    The report is written to ``--output``, and a ``.json`` file with the
    same stem stores the config, a summary and every per-user result.

    Prints an error and returns early when the model file is missing or
    the images cannot be loaded.
    """
    args = parse_args()

    # Check the model file
    if not Path(args.model).exists():
        print(f"❌ 模型文件不存在: {args.model}")
        return

    # Load image paths
    print("\n📂 加载图片路径...")
    try:
        image_paths = load_image_paths(args.source)
        print(f"✓ 找到 {len(image_paths)} 张图片")
    except Exception as e:
        print(f"❌ 加载图片失败: {e}")
        return

    # Show the test configuration
    total_images = args.users * args.images_per_user
    print("\n⚙️ 测试配置:")
    print(f" - 并发用户数: {args.users}")
    print(f" - 每用户图片数: {args.images_per_user}")
    print(f" - 总图片数: {total_images}")
    print(f" - 模型: {Path(args.model).name}")
    print(f" - 置信度: {args.conf}")
    print(f" - 图片大小: {args.imgsz}")

    print("\n🚀 开始并发测试...\n")
    # Monotonic clock for the overall duration. NOTE: this window also
    # includes each worker's model loading (done inside user_task), so the
    # resulting QPS is an end-to-end figure.
    start_time = time.perf_counter()

    # One worker thread per simulated user
    user_results = []
    with ThreadPoolExecutor(max_workers=args.users) as executor:
        futures = [
            executor.submit(
                user_task,
                user_id,
                args.model,
                image_paths,
                args.images_per_user,
                args.conf,
                args.imgsz,
                args.verbose,
            )
            for user_id in range(args.users)
        ]

        # Collect results as they finish and show progress
        for completed, future in enumerate(as_completed(futures), start=1):
            user_results.append(future.result())
            if not args.verbose:
                print(f" 进度: {completed}/{args.users} 用户完成 "
                      f"({completed/args.users*100:.0f}%)", end='\r')

    total_time = time.perf_counter() - start_time
    if not args.verbose:
        print()  # finish the '\r' progress line

    # Build and show the report
    print("\n📊 生成性能报告...")
    report = generate_report(user_results, total_time, args)
    print(report)

    # Save the text report
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"\n✓ 报告已保存到: {args.output}")

    # Save the detailed data as JSON next to the text report
    json_output = output_path.with_suffix('.json')
    json_data = {
        'config': vars(args),
        'summary': {
            'total_time': total_time,
            'total_images': total_images,
            'qps': total_images / total_time if total_time > 0 else 0,
        },
        'users': user_results,
    }
    with open(json_output, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)
    print(f"✓ 详细数据已保存到: {json_output}")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()