""" YOLO数字识别 - 基础版本 功能说明: 使用训练好的YOLO模型识别图片中的4位阿拉伯数字。 这是基础版本,提供简单的数字检测和识别功能。 主要特性: - 批量处理图片文件夹 - 支持自定义置信度阈值 - 从左到右排序数字 - 生成可视化结果(可选) - 输出识别结果到文本文件 算法流程: 1. 加载YOLO模型 2. 对每张图片进行目标检测 3. 提取检测到的数字(0-9) 4. 按x坐标从左到右排序 5. 组合成完整数字串 适用场景: - 快速测试模型效果 - 简单的数字识别任务 - 作为改进版的基准对比 注意事项: - 不包含智能过滤,可能识别出非4位数字 - 对于复杂场景建议使用 predict_digits_improved.py 使用示例: # 基础使用 python scripts/predict_digits.py # 自定义参数 python scripts/predict_digits.py \ --model runs/digit_yolo/exp1/weights/best.pt \ --source valid \ --conf 0.25 \ --save-vis # 高清识别 python scripts/predict_digits.py --imgsz 640 --conf 0.2 作者: Gavin Chan 版本: 1.0 日期: 2025-10-30 """ from __future__ import annotations import argparse from pathlib import Path from typing import List, Tuple from ultralytics import YOLO import cv2 def parse_args() -> argparse.Namespace: """ 解析命令行参数 Returns: argparse.Namespace: 包含所有命令行参数的对象 - model: YOLO模型文件路径 - source: 待识别图片的文件夹路径 - conf: 置信度阈值(0-1之间) - imgsz: 输入图片尺寸 - output: 输出结果文件路径 - save_vis: 是否保存可视化结果 """ parser = argparse.ArgumentParser(description="识别4位数字图片") parser.add_argument( "--model", type=Path, default=Path("runs/digit_yolo/exp1/weights/best.pt"), help="训练好的YOLO模型路径" ) parser.add_argument( "--source", type=Path, default=Path("valid"), help="待识别图片的文件夹路径" ) parser.add_argument( "--conf", type=float, default=0.25, help="置信度阈值" ) parser.add_argument( "--imgsz", type=int, default=320, help="输入图片大小" ) parser.add_argument( "--output", type=Path, default=Path("results/predictions.txt"), help="输出结果文件路径" ) parser.add_argument( "--save-vis", action="store_true", help="是否保存可视化结果" ) return parser.parse_args() def extract_digits_from_predictions(results, img_width: int) -> str: """ 从YOLO预测结果中提取数字并按位置排序 处理流程: 1. 遍历所有检测框 2. 提取边界框的x坐标中心点 3. 获取每个检测框的类别(0-9)和置信度 4. 按x坐标从左到右排序 5. 组合成完整的数字字符串 Args: results: YOLO模型的预测结果对象 - results.boxes: 检测框信息 - results.boxes.xyxy: 边界框坐标 [x1, y1, x2, y2] - results.boxes.cls: 类别ID(0-9对应数字0-9) - results.boxes.conf: 置信度分数 img_width: 图片宽度(像素),用于坐标归一化(当前版本未使用) Returns: str: 识别出的数字字符串,如 "1234",可能不足或超过4位 示例: >>> results = model.predict("image.jpg")[0] >>> digits = extract_digits_from_predictions(results, 640) >>> print(digits) # "3809" """ # 提取检测框和类别 detections: List[Tuple[float, int]] = [] # (x_center, digit_class) if results.boxes is not None and len(results.boxes) > 0: boxes = results.boxes for i in range(len(boxes)): # 获取边界框坐标 (x1, y1, x2, y2) box = boxes.xyxy[i].cpu().numpy() x_center = (box[0] + box[2]) / 2 # 获取类别(数字0-9) cls = int(boxes.cls[i].cpu().numpy()) # 获取置信度 conf = float(boxes.conf[i].cpu().numpy()) detections.append((x_center, cls, conf)) # 按照x坐标从左到右排序 detections.sort(key=lambda x: x[0]) # 提取数字 digits = [str(det[1]) for det in detections] # 组合成4位数字字符串 result = "".join(digits) return result def predict_single_image(model: YOLO, image_path: Path, conf: float, imgsz: int) -> Tuple[str, float]: """ 预测单张图片中的数字 处理流程: 1. 使用OpenCV读取图片获取尺寸信息 2. 调用YOLO模型进行目标检测 3. 提取并排序检测到的数字 4. 计算平均置信度作为质量指标 Args: model (YOLO): 已加载的YOLO模型对象 image_path (Path): 图片文件的完整路径 conf (float): 置信度阈值(0-1),低于此值的检测将被过滤 imgsz (int): 模型输入图片大小,如320或640 Returns: Tuple[str, float]: 二元组 - str: 识别出的数字字符串,如"1234"或"567"(可能不足4位) - float: 所有检测框的平均置信度,范围0-1 异常处理: - 如果图片无法读取,返回 ("", 0.0) 并打印警告 - 如果没有检测到任何数字,返回 ("", 0.0) 示例: >>> model = YOLO("best.pt") >>> digits, conf = predict_single_image(model, Path("test.jpg"), 0.25, 320) >>> print(f"识别结果: {digits}, 置信度: {conf:.3f}") 识别结果: 3809, 置信度: 0.584 """ # 读取图片获取宽度 img = cv2.imread(str(image_path)) if img is None: print(f"警告:无法读取图片 {image_path}") return "", 0.0 img_height, img_width = img.shape[:2] # 进行预测 results = model.predict( source=str(image_path), conf=conf, imgsz=imgsz, verbose=False )[0] # 提取数字 digits = extract_digits_from_predictions(results, img_width) # 计算平均置信度 avg_conf = 0.0 if results.boxes is not None and len(results.boxes) > 0: confs = results.boxes.conf.cpu().numpy() avg_conf = float(confs.mean()) return digits, avg_conf def main() -> None: """ 主函数:执行批量数字识别流程 完整流程: 1. 解析命令行参数 2. 验证模型文件和图片目录是否存在 3. 加载YOLO模型 4. 遍历所有图片文件进行识别 5. 统计识别结果(正确率、置信度等) 6. 保存结果到文本文件 7. 可选:生成带标注的可视化图片 输出格式: 控制台输出: - 每张图片的识别结果 - 统计信息(正确率等) - 文件保存路径 文本文件(results/predictions.txt): 文件名 识别结果 置信度 数字个数 YZM.jpeg 3809 0.584 4 ... 异常处理: - FileNotFoundError: 模型或图片目录不存在 - 其他异常会向上传播 注意: - 需要预先安装 ultralytics 和 opencv-python - 模型文件需要是训练好的 .pt 格式 - 支持的图片格式: .jpg, .jpeg, .png, .bmp """ args = parse_args() # 检查模型文件 if not args.model.exists(): raise FileNotFoundError(f"模型文件不存在: {args.model}") # 检查源文件夹 if not args.source.exists(): raise FileNotFoundError(f"源文件夹不存在: {args.source}") # 加载模型 print(f"加载模型: {args.model}") model = YOLO(str(args.model)) # 获取所有图片文件 image_extensions = [".jpg", ".jpeg", ".png", ".bmp"] image_files = [] for ext in image_extensions: image_files.extend(args.source.glob(f"*{ext}")) image_files.extend(args.source.glob(f"*{ext.upper()}")) image_files = sorted(image_files) if not image_files: print(f"在 {args.source} 中没有找到图片文件") return print(f"找到 {len(image_files)} 张图片") print("-" * 80) # 预测结果 results = [] for image_path in image_files: digits, conf = predict_single_image(model, image_path, args.conf, args.imgsz) # 检查是否识别出4位数字 if len(digits) != 4: status = f"⚠️ 检测到 {len(digits)} 位数字" else: status = "✓" result_line = f"{image_path.name:<20} -> {digits:<6} (置信度: {conf:.3f}) {status}" print(result_line) results.append({ "filename": image_path.name, "digits": digits, "confidence": conf, "digit_count": len(digits) }) print("-" * 80) print(f"识别完成!") # 统计信息 correct_count = sum(1 for r in results if r["digit_count"] == 4) print(f"正确识别4位数字: {correct_count}/{len(results)} ({correct_count/len(results)*100:.1f}%)") # 保存结果 args.output.parent.mkdir(parents=True, exist_ok=True) with args.output.open("w", encoding="utf-8") as f: f.write("文件名\t识别结果\t置信度\t数字个数\n") for r in results: f.write(f"{r['filename']}\t{r['digits']}\t{r['confidence']:.3f}\t{r['digit_count']}\n") print(f"结果已保存到: {args.output}") # 如果需要保存可视化结果 if args.save_vis: print("\n生成可视化结果...") output_dir = args.output.parent / "visualizations" model.predict( source=str(args.source), conf=args.conf, imgsz=args.imgsz, save=True, project=str(output_dir.parent), name=output_dir.name, exist_ok=True ) print(f"可视化结果已保存到: {output_dir}") if __name__ == "__main__": main()