first commit
This commit is contained in:
346
scripts/predict_digits.py
Normal file
346
scripts/predict_digits.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
YOLO数字识别 - 基础版本
|
||||
|
||||
功能说明:
|
||||
使用训练好的YOLO模型识别图片中的4位阿拉伯数字。
|
||||
这是基础版本,提供简单的数字检测和识别功能。
|
||||
|
||||
主要特性:
|
||||
- 批量处理图片文件夹
|
||||
- 支持自定义置信度阈值
|
||||
- 从左到右排序数字
|
||||
- 生成可视化结果(可选)
|
||||
- 输出识别结果到文本文件
|
||||
|
||||
算法流程:
|
||||
1. 加载YOLO模型
|
||||
2. 对每张图片进行目标检测
|
||||
3. 提取检测到的数字(0-9)
|
||||
4. 按x坐标从左到右排序
|
||||
5. 组合成完整数字串
|
||||
|
||||
适用场景:
|
||||
- 快速测试模型效果
|
||||
- 简单的数字识别任务
|
||||
- 作为改进版的基准对比
|
||||
|
||||
注意事项:
|
||||
- 不包含智能过滤,可能识别出非4位数字
|
||||
- 对于复杂场景建议使用 predict_digits_improved.py
|
||||
|
||||
使用示例:
|
||||
# 基础使用
|
||||
python scripts/predict_digits.py
|
||||
|
||||
# 自定义参数
|
||||
python scripts/predict_digits.py \
|
||||
--model runs/digit_yolo/exp1/weights/best.pt \
|
||||
--source valid \
|
||||
--conf 0.25 \
|
||||
--save-vis
|
||||
|
||||
# 高清识别
|
||||
python scripts/predict_digits.py --imgsz 640 --conf 0.2
|
||||
|
||||
作者: Gavin Chan
|
||||
版本: 1.0
|
||||
日期: 2025-10-30
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
解析命令行参数
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: 包含所有命令行参数的对象
|
||||
- model: YOLO模型文件路径
|
||||
- source: 待识别图片的文件夹路径
|
||||
- conf: 置信度阈值(0-1之间)
|
||||
- imgsz: 输入图片尺寸
|
||||
- output: 输出结果文件路径
|
||||
- save_vis: 是否保存可视化结果
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="识别4位数字图片")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=Path,
|
||||
default=Path("runs/digit_yolo/exp1/weights/best.pt"),
|
||||
help="训练好的YOLO模型路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=Path,
|
||||
default=Path("valid"),
|
||||
help="待识别图片的文件夹路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conf",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="置信度阈值"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--imgsz",
|
||||
type=int,
|
||||
default=320,
|
||||
help="输入图片大小"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("results/predictions.txt"),
|
||||
help="输出结果文件路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-vis",
|
||||
action="store_true",
|
||||
help="是否保存可视化结果"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def extract_digits_from_predictions(results, img_width: int) -> str:
|
||||
"""
|
||||
从YOLO预测结果中提取数字并按位置排序
|
||||
|
||||
处理流程:
|
||||
1. 遍历所有检测框
|
||||
2. 提取边界框的x坐标中心点
|
||||
3. 获取每个检测框的类别(0-9)和置信度
|
||||
4. 按x坐标从左到右排序
|
||||
5. 组合成完整的数字字符串
|
||||
|
||||
Args:
|
||||
results: YOLO模型的预测结果对象
|
||||
- results.boxes: 检测框信息
|
||||
- results.boxes.xyxy: 边界框坐标 [x1, y1, x2, y2]
|
||||
- results.boxes.cls: 类别ID(0-9对应数字0-9)
|
||||
- results.boxes.conf: 置信度分数
|
||||
img_width: 图片宽度(像素),用于坐标归一化(当前版本未使用)
|
||||
|
||||
Returns:
|
||||
str: 识别出的数字字符串,如 "1234",可能不足或超过4位
|
||||
|
||||
示例:
|
||||
>>> results = model.predict("image.jpg")[0]
|
||||
>>> digits = extract_digits_from_predictions(results, 640)
|
||||
>>> print(digits) # "3809"
|
||||
"""
|
||||
# 提取检测框和类别
|
||||
detections: List[Tuple[float, int]] = [] # (x_center, digit_class)
|
||||
|
||||
if results.boxes is not None and len(results.boxes) > 0:
|
||||
boxes = results.boxes
|
||||
for i in range(len(boxes)):
|
||||
# 获取边界框坐标 (x1, y1, x2, y2)
|
||||
box = boxes.xyxy[i].cpu().numpy()
|
||||
x_center = (box[0] + box[2]) / 2
|
||||
|
||||
# 获取类别(数字0-9)
|
||||
cls = int(boxes.cls[i].cpu().numpy())
|
||||
|
||||
# 获取置信度
|
||||
conf = float(boxes.conf[i].cpu().numpy())
|
||||
|
||||
detections.append((x_center, cls, conf))
|
||||
|
||||
# 按照x坐标从左到右排序
|
||||
detections.sort(key=lambda x: x[0])
|
||||
|
||||
# 提取数字
|
||||
digits = [str(det[1]) for det in detections]
|
||||
|
||||
# 组合成4位数字字符串
|
||||
result = "".join(digits)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def predict_single_image(model: YOLO, image_path: Path, conf: float, imgsz: int) -> Tuple[str, float]:
|
||||
"""
|
||||
预测单张图片中的数字
|
||||
|
||||
处理流程:
|
||||
1. 使用OpenCV读取图片获取尺寸信息
|
||||
2. 调用YOLO模型进行目标检测
|
||||
3. 提取并排序检测到的数字
|
||||
4. 计算平均置信度作为质量指标
|
||||
|
||||
Args:
|
||||
model (YOLO): 已加载的YOLO模型对象
|
||||
image_path (Path): 图片文件的完整路径
|
||||
conf (float): 置信度阈值(0-1),低于此值的检测将被过滤
|
||||
imgsz (int): 模型输入图片大小,如320或640
|
||||
|
||||
Returns:
|
||||
Tuple[str, float]: 二元组
|
||||
- str: 识别出的数字字符串,如"1234"或"567"(可能不足4位)
|
||||
- float: 所有检测框的平均置信度,范围0-1
|
||||
|
||||
异常处理:
|
||||
- 如果图片无法读取,返回 ("", 0.0) 并打印警告
|
||||
- 如果没有检测到任何数字,返回 ("", 0.0)
|
||||
|
||||
示例:
|
||||
>>> model = YOLO("best.pt")
|
||||
>>> digits, conf = predict_single_image(model, Path("test.jpg"), 0.25, 320)
|
||||
>>> print(f"识别结果: {digits}, 置信度: {conf:.3f}")
|
||||
识别结果: 3809, 置信度: 0.584
|
||||
"""
|
||||
# 读取图片获取宽度
|
||||
img = cv2.imread(str(image_path))
|
||||
if img is None:
|
||||
print(f"警告:无法读取图片 {image_path}")
|
||||
return "", 0.0
|
||||
|
||||
img_height, img_width = img.shape[:2]
|
||||
|
||||
# 进行预测
|
||||
results = model.predict(
|
||||
source=str(image_path),
|
||||
conf=conf,
|
||||
imgsz=imgsz,
|
||||
verbose=False
|
||||
)[0]
|
||||
|
||||
# 提取数字
|
||||
digits = extract_digits_from_predictions(results, img_width)
|
||||
|
||||
# 计算平均置信度
|
||||
avg_conf = 0.0
|
||||
if results.boxes is not None and len(results.boxes) > 0:
|
||||
confs = results.boxes.conf.cpu().numpy()
|
||||
avg_conf = float(confs.mean())
|
||||
|
||||
return digits, avg_conf
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
主函数:执行批量数字识别流程
|
||||
|
||||
完整流程:
|
||||
1. 解析命令行参数
|
||||
2. 验证模型文件和图片目录是否存在
|
||||
3. 加载YOLO模型
|
||||
4. 遍历所有图片文件进行识别
|
||||
5. 统计识别结果(正确率、置信度等)
|
||||
6. 保存结果到文本文件
|
||||
7. 可选:生成带标注的可视化图片
|
||||
|
||||
输出格式:
|
||||
控制台输出:
|
||||
- 每张图片的识别结果
|
||||
- 统计信息(正确率等)
|
||||
- 文件保存路径
|
||||
|
||||
文本文件(results/predictions.txt):
|
||||
文件名 识别结果 置信度 数字个数
|
||||
YZM.jpeg 3809 0.584 4
|
||||
...
|
||||
|
||||
异常处理:
|
||||
- FileNotFoundError: 模型或图片目录不存在
|
||||
- 其他异常会向上传播
|
||||
|
||||
注意:
|
||||
- 需要预先安装 ultralytics 和 opencv-python
|
||||
- 模型文件需要是训练好的 .pt 格式
|
||||
- 支持的图片格式: .jpg, .jpeg, .png, .bmp
|
||||
"""
|
||||
args = parse_args()
|
||||
|
||||
# 检查模型文件
|
||||
if not args.model.exists():
|
||||
raise FileNotFoundError(f"模型文件不存在: {args.model}")
|
||||
|
||||
# 检查源文件夹
|
||||
if not args.source.exists():
|
||||
raise FileNotFoundError(f"源文件夹不存在: {args.source}")
|
||||
|
||||
# 加载模型
|
||||
print(f"加载模型: {args.model}")
|
||||
model = YOLO(str(args.model))
|
||||
|
||||
# 获取所有图片文件
|
||||
image_extensions = [".jpg", ".jpeg", ".png", ".bmp"]
|
||||
image_files = []
|
||||
for ext in image_extensions:
|
||||
image_files.extend(args.source.glob(f"*{ext}"))
|
||||
image_files.extend(args.source.glob(f"*{ext.upper()}"))
|
||||
|
||||
image_files = sorted(image_files)
|
||||
|
||||
if not image_files:
|
||||
print(f"在 {args.source} 中没有找到图片文件")
|
||||
return
|
||||
|
||||
print(f"找到 {len(image_files)} 张图片")
|
||||
print("-" * 80)
|
||||
|
||||
# 预测结果
|
||||
results = []
|
||||
|
||||
for image_path in image_files:
|
||||
digits, conf = predict_single_image(model, image_path, args.conf, args.imgsz)
|
||||
|
||||
# 检查是否识别出4位数字
|
||||
if len(digits) != 4:
|
||||
status = f"⚠️ 检测到 {len(digits)} 位数字"
|
||||
else:
|
||||
status = "✓"
|
||||
|
||||
result_line = f"{image_path.name:<20} -> {digits:<6} (置信度: {conf:.3f}) {status}"
|
||||
print(result_line)
|
||||
|
||||
results.append({
|
||||
"filename": image_path.name,
|
||||
"digits": digits,
|
||||
"confidence": conf,
|
||||
"digit_count": len(digits)
|
||||
})
|
||||
|
||||
print("-" * 80)
|
||||
print(f"识别完成!")
|
||||
|
||||
# 统计信息
|
||||
correct_count = sum(1 for r in results if r["digit_count"] == 4)
|
||||
print(f"正确识别4位数字: {correct_count}/{len(results)} ({correct_count/len(results)*100:.1f}%)")
|
||||
|
||||
# 保存结果
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
with args.output.open("w", encoding="utf-8") as f:
|
||||
f.write("文件名\t识别结果\t置信度\t数字个数\n")
|
||||
for r in results:
|
||||
f.write(f"{r['filename']}\t{r['digits']}\t{r['confidence']:.3f}\t{r['digit_count']}\n")
|
||||
|
||||
print(f"结果已保存到: {args.output}")
|
||||
|
||||
# 如果需要保存可视化结果
|
||||
if args.save_vis:
|
||||
print("\n生成可视化结果...")
|
||||
output_dir = args.output.parent / "visualizations"
|
||||
model.predict(
|
||||
source=str(args.source),
|
||||
conf=args.conf,
|
||||
imgsz=args.imgsz,
|
||||
save=True,
|
||||
project=str(output_dir.parent),
|
||||
name=output_dir.name,
|
||||
exist_ok=True
|
||||
)
|
||||
print(f"可视化结果已保存到: {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user