Files
digit-cracker/scripts/train_yolo.py
2025-10-30 15:40:56 +08:00

198 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLO数字识别模型训练脚本
功能说明:
使用YOLOv8在准备好的数字数据集上训练目标检测模型。
支持从预训练模型开始进行迁移学习,加速训练过程。
主要功能:
- 加载YOLO预训练模型yolov8n.pt等
- 在数字数据集上进行训练
- 自动保存最佳模型和最后模型
- 训练完成后自动验证
- 可选在valid文件夹上进行推理测试
训练流程:
1. 加载预训练模型ImageNet或COCO预训练
2. 在数字数据集上微调
3. 每个epoch保存检查点
4. 根据验证集mAP保存最佳模型
5. 训练完成后加载最佳模型进行验证
输出文件:
runs/digit_yolo/<name>/
├── weights/
│ ├── best.pt # 最佳模型验证集mAP最高
│ └── last.pt # 最后一个epoch的模型
├── results.csv # 训练指标loss, mAP等
├── results.png # 训练曲线图
├── confusion_matrix.png # 混淆矩阵
└── args.yaml # 训练参数记录
训练参数说明:
- epochs: 训练轮数100-200推荐
- batch: 批次大小根据显存调整CPU建议8-16
- imgsz: 输入图片大小320快速640精确
- model: 预训练模型yolov8n最轻量yolov8s/m更准确
性能优化建议:
CPU训练:
- batch=8-16
- imgsz=320
- workers=4
- 训练时间: ~2-3小时/100轮
GPU训练:
- batch=32-64
- imgsz=640
- 训练时间: ~10-20分钟/100轮
使用示例:
# 基础训练100轮
python scripts/train_yolo.py
# 长时间训练200轮
python scripts/train_yolo.py --epochs 200 --name exp_200
# 使用更大模型
python scripts/train_yolo.py --model yolov8s.pt --epochs 150
# 高清训练
python scripts/train_yolo.py --imgsz 640 --batch 8 --name exp_hd
# 自定义输出目录
python scripts/train_yolo.py \
--project my_runs \
--name my_experiment \
--epochs 150
监控训练:
# 实时查看训练指标
tail -f runs/digit_yolo/<name>/results.csv
# TensorBoard可视化可选
tensorboard --logdir runs/digit_yolo
依赖环境:
- ultralytics >= 8.0.0
- torch >= 2.0.0
- opencv-python
作者: Gavin Chan
版本: 1.0
日期: 2025-10-30
"""
from __future__ import annotations
import argparse
from pathlib import Path
from ultralytics import YOLO
def parse_args() -> argparse.Namespace:
"""
解析命令行参数
Returns:
argparse.Namespace: 训练配置参数
- data: 数据集配置文件路径dataset.yaml
- model: 预训练模型名称或路径
- epochs: 训练轮数
- imgsz: 输入图片大小
- batch: 批次大小
- project: 输出项目目录
- name: 实验名称
- valid_dir: 额外验证图片目录
"""
parser = argparse.ArgumentParser(description="Train YOLO model for digit recognition")
parser.add_argument("--data", type=Path, default=Path("yolo_dataset/dataset.yaml"), help="path to dataset yaml")
parser.add_argument("--model", type=str, default="yolov8n.pt", help="pretrained YOLO checkpoint")
parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--imgsz", type=int, default=320, help="image size")
parser.add_argument("--batch", type=int, default=16, help="batch size")
parser.add_argument("--project", type=str, default="runs/digit_yolo", help="training output directory")
parser.add_argument("--name", type=str, default="exp", help="run name")
parser.add_argument(
"--valid-dir", type=Path, default=Path("valid"), help="directory with four-digit images for evaluation"
)
return parser.parse_args()
def main() -> None:
"""
主函数执行YOLO模型训练流程
完整流程:
1. 解析命令行参数
2. 加载YOLO预训练模型
3. 开始训练(自动保存检查点)
4. 训练完成后加载最佳模型
5. 在验证集上评估性能
6. 可选在valid文件夹上进行推理
训练输出:
- 每个epoch的训练和验证指标
- 混淆矩阵
- PR曲线
- 训练曲线图
- 最佳和最后模型权重
验证指标:
- mAP50: IoU=0.5时的mAP主要指标
- mAP50-95: IoU从0.5到0.95的平均mAP
- Precision: 精确率
- Recall: 召回率
- 每个类别数字0-9的性能
异常处理:
- FileNotFoundError: 数据集配置文件不存在
- RuntimeError: 训练失败或模型加载失败
"""
args = parse_args()
model = YOLO(args.model)
results = model.train(
data=str(args.data),
epochs=args.epochs,
imgsz=args.imgsz,
batch=args.batch,
project=args.project,
name=args.name,
exist_ok=True,
)
print("Training complete. Summary metrics:")
print(results)
best_ckpt = Path(results.save_dir) / "weights" / "best.pt"
if not best_ckpt.exists():
raise FileNotFoundError(f"Best checkpoint not found at {best_ckpt}")
# Validate on the validation split
model = YOLO(str(best_ckpt))
print("Running validation...")
val_metrics = model.val(data=str(args.data), imgsz=args.imgsz, project=args.project, name=f"{args.name}_val")
print(val_metrics)
# Inference on the valid folder
if args.valid_dir.exists():
print(f"Running inference on {args.valid_dir} ...")
model.predict(
source=str(args.valid_dir),
imgsz=args.imgsz,
save=True,
save_txt=True,
project=args.project,
name=f"{args.name}_valid",
)
print(f"Predictions saved to {Path(args.project) / f'{args.name}_valid'}")
else:
print(f"Valid directory {args.valid_dir} not found; skipping inference.")
if __name__ == "__main__":
main()