Files
digit-cracker/scripts/predict_digits.py
2025-10-30 15:40:56 +08:00

347 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLO数字识别 - 基础版本
功能说明:
使用训练好的YOLO模型识别图片中的4位阿拉伯数字。
这是基础版本,提供简单的数字检测和识别功能。
主要特性:
- 批量处理图片文件夹
- 支持自定义置信度阈值
- 从左到右排序数字
- 生成可视化结果(可选)
- 输出识别结果到文本文件
算法流程:
1. 加载YOLO模型
2. 对每张图片进行目标检测
3. 提取检测到的数字0-9
4. 按x坐标从左到右排序
5. 组合成完整数字串
适用场景:
- 快速测试模型效果
- 简单的数字识别任务
- 作为改进版的基准对比
注意事项:
- 不包含智能过滤可能识别出非4位数字
- 对于复杂场景建议使用 predict_digits_improved.py
使用示例:
# 基础使用
python scripts/predict_digits.py
# 自定义参数
python scripts/predict_digits.py \
--model runs/digit_yolo/exp1/weights/best.pt \
--source valid \
--conf 0.25 \
--save-vis
# 高清识别
python scripts/predict_digits.py --imgsz 640 --conf 0.2
作者: Gavin Chan
版本: 1.0
日期: 2025-10-30
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import List, Tuple
from ultralytics import YOLO
import cv2
def parse_args() -> argparse.Namespace:
"""
解析命令行参数
Returns:
argparse.Namespace: 包含所有命令行参数的对象
- model: YOLO模型文件路径
- source: 待识别图片的文件夹路径
- conf: 置信度阈值0-1之间
- imgsz: 输入图片尺寸
- output: 输出结果文件路径
- save_vis: 是否保存可视化结果
"""
parser = argparse.ArgumentParser(description="识别4位数字图片")
parser.add_argument(
"--model",
type=Path,
default=Path("runs/digit_yolo/exp1/weights/best.pt"),
help="训练好的YOLO模型路径"
)
parser.add_argument(
"--source",
type=Path,
default=Path("valid"),
help="待识别图片的文件夹路径"
)
parser.add_argument(
"--conf",
type=float,
default=0.25,
help="置信度阈值"
)
parser.add_argument(
"--imgsz",
type=int,
default=320,
help="输入图片大小"
)
parser.add_argument(
"--output",
type=Path,
default=Path("results/predictions.txt"),
help="输出结果文件路径"
)
parser.add_argument(
"--save-vis",
action="store_true",
help="是否保存可视化结果"
)
return parser.parse_args()
def extract_digits_from_predictions(results, img_width: int) -> str:
"""
从YOLO预测结果中提取数字并按位置排序
处理流程:
1. 遍历所有检测框
2. 提取边界框的x坐标中心点
3. 获取每个检测框的类别0-9和置信度
4. 按x坐标从左到右排序
5. 组合成完整的数字字符串
Args:
results: YOLO模型的预测结果对象
- results.boxes: 检测框信息
- results.boxes.xyxy: 边界框坐标 [x1, y1, x2, y2]
- results.boxes.cls: 类别ID0-9对应数字0-9
- results.boxes.conf: 置信度分数
img_width: 图片宽度(像素),用于坐标归一化(当前版本未使用)
Returns:
str: 识别出的数字字符串,如 "1234"可能不足或超过4位
示例:
>>> results = model.predict("image.jpg")[0]
>>> digits = extract_digits_from_predictions(results, 640)
>>> print(digits) # "3809"
"""
# 提取检测框和类别
detections: List[Tuple[float, int]] = [] # (x_center, digit_class)
if results.boxes is not None and len(results.boxes) > 0:
boxes = results.boxes
for i in range(len(boxes)):
# 获取边界框坐标 (x1, y1, x2, y2)
box = boxes.xyxy[i].cpu().numpy()
x_center = (box[0] + box[2]) / 2
# 获取类别数字0-9
cls = int(boxes.cls[i].cpu().numpy())
# 获取置信度
conf = float(boxes.conf[i].cpu().numpy())
detections.append((x_center, cls, conf))
# 按照x坐标从左到右排序
detections.sort(key=lambda x: x[0])
# 提取数字
digits = [str(det[1]) for det in detections]
# 组合成4位数字字符串
result = "".join(digits)
return result
def predict_single_image(model: YOLO, image_path: Path, conf: float, imgsz: int) -> Tuple[str, float]:
"""
预测单张图片中的数字
处理流程:
1. 使用OpenCV读取图片获取尺寸信息
2. 调用YOLO模型进行目标检测
3. 提取并排序检测到的数字
4. 计算平均置信度作为质量指标
Args:
model (YOLO): 已加载的YOLO模型对象
image_path (Path): 图片文件的完整路径
conf (float): 置信度阈值0-1低于此值的检测将被过滤
imgsz (int): 模型输入图片大小如320或640
Returns:
Tuple[str, float]: 二元组
- str: 识别出的数字字符串,如"1234""567"可能不足4位
- float: 所有检测框的平均置信度范围0-1
异常处理:
- 如果图片无法读取,返回 ("", 0.0) 并打印警告
- 如果没有检测到任何数字,返回 ("", 0.0)
示例:
>>> model = YOLO("best.pt")
>>> digits, conf = predict_single_image(model, Path("test.jpg"), 0.25, 320)
>>> print(f"识别结果: {digits}, 置信度: {conf:.3f}")
识别结果: 3809, 置信度: 0.584
"""
# 读取图片获取宽度
img = cv2.imread(str(image_path))
if img is None:
print(f"警告:无法读取图片 {image_path}")
return "", 0.0
img_height, img_width = img.shape[:2]
# 进行预测
results = model.predict(
source=str(image_path),
conf=conf,
imgsz=imgsz,
verbose=False
)[0]
# 提取数字
digits = extract_digits_from_predictions(results, img_width)
# 计算平均置信度
avg_conf = 0.0
if results.boxes is not None and len(results.boxes) > 0:
confs = results.boxes.conf.cpu().numpy()
avg_conf = float(confs.mean())
return digits, avg_conf
def main() -> None:
"""
主函数:执行批量数字识别流程
完整流程:
1. 解析命令行参数
2. 验证模型文件和图片目录是否存在
3. 加载YOLO模型
4. 遍历所有图片文件进行识别
5. 统计识别结果(正确率、置信度等)
6. 保存结果到文本文件
7. 可选:生成带标注的可视化图片
输出格式:
控制台输出:
- 每张图片的识别结果
- 统计信息(正确率等)
- 文件保存路径
文本文件results/predictions.txt:
文件名 识别结果 置信度 数字个数
YZM.jpeg 3809 0.584 4
...
异常处理:
- FileNotFoundError: 模型或图片目录不存在
- 其他异常会向上传播
注意:
- 需要预先安装 ultralytics 和 opencv-python
- 模型文件需要是训练好的 .pt 格式
- 支持的图片格式: .jpg, .jpeg, .png, .bmp
"""
args = parse_args()
# 检查模型文件
if not args.model.exists():
raise FileNotFoundError(f"模型文件不存在: {args.model}")
# 检查源文件夹
if not args.source.exists():
raise FileNotFoundError(f"源文件夹不存在: {args.source}")
# 加载模型
print(f"加载模型: {args.model}")
model = YOLO(str(args.model))
# 获取所有图片文件
image_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
image_files = []
for ext in image_extensions:
image_files.extend(args.source.glob(f"*{ext}"))
image_files.extend(args.source.glob(f"*{ext.upper()}"))
image_files = sorted(image_files)
if not image_files:
print(f"{args.source} 中没有找到图片文件")
return
print(f"找到 {len(image_files)} 张图片")
print("-" * 80)
# 预测结果
results = []
for image_path in image_files:
digits, conf = predict_single_image(model, image_path, args.conf, args.imgsz)
# 检查是否识别出4位数字
if len(digits) != 4:
status = f"⚠️ 检测到 {len(digits)} 位数字"
else:
status = ""
result_line = f"{image_path.name:<20} -> {digits:<6} (置信度: {conf:.3f}) {status}"
print(result_line)
results.append({
"filename": image_path.name,
"digits": digits,
"confidence": conf,
"digit_count": len(digits)
})
print("-" * 80)
print(f"识别完成!")
# 统计信息
correct_count = sum(1 for r in results if r["digit_count"] == 4)
print(f"正确识别4位数字: {correct_count}/{len(results)} ({correct_count/len(results)*100:.1f}%)")
# 保存结果
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", encoding="utf-8") as f:
f.write("文件名\t识别结果\t置信度\t数字个数\n")
for r in results:
f.write(f"{r['filename']}\t{r['digits']}\t{r['confidence']:.3f}\t{r['digit_count']}\n")
print(f"结果已保存到: {args.output}")
# 如果需要保存可视化结果
if args.save_vis:
print("\n生成可视化结果...")
output_dir = args.output.parent / "visualizations"
model.predict(
source=str(args.source),
conf=args.conf,
imgsz=args.imgsz,
save=True,
project=str(output_dir.parent),
name=output_dir.name,
exist_ok=True
)
print(f"可视化结果已保存到: {output_dir}")
if __name__ == "__main__":
main()