198 lines
6.1 KiB
Python
198 lines
6.1 KiB
Python
"""
|
||
YOLO数字识别模型训练脚本
|
||
|
||
功能说明:
|
||
使用YOLOv8在准备好的数字数据集上训练目标检测模型。
|
||
支持从预训练模型开始进行迁移学习,加速训练过程。
|
||
|
||
主要功能:
|
||
- 加载YOLO预训练模型(yolov8n.pt等)
|
||
- 在数字数据集上进行训练
|
||
- 自动保存最佳模型和最后模型
|
||
- 训练完成后自动验证
|
||
- 可选:在valid文件夹上进行推理测试
|
||
|
||
训练流程:
|
||
1. 加载预训练模型(ImageNet或COCO预训练)
|
||
2. 在数字数据集上微调
|
||
3. 每个epoch保存检查点
|
||
4. 根据验证集mAP保存最佳模型
|
||
5. 训练完成后加载最佳模型进行验证
|
||
|
||
输出文件:
|
||
runs/digit_yolo/<name>/
|
||
├── weights/
|
||
│ ├── best.pt # 最佳模型(验证集mAP最高)
|
||
│ └── last.pt # 最后一个epoch的模型
|
||
├── results.csv # 训练指标(loss, mAP等)
|
||
├── results.png # 训练曲线图
|
||
├── confusion_matrix.png # 混淆矩阵
|
||
└── args.yaml # 训练参数记录
|
||
|
||
训练参数说明:
|
||
- epochs: 训练轮数(100-200推荐)
|
||
- batch: 批次大小(根据显存调整,CPU建议8-16)
|
||
- imgsz: 输入图片大小(320快速,640精确)
|
||
- model: 预训练模型(yolov8n最轻量,yolov8s/m更准确)
|
||
|
||
性能优化建议:
|
||
CPU训练:
|
||
- batch=8-16
|
||
- imgsz=320
|
||
- workers=4
|
||
- 训练时间: ~2-3小时/100轮
|
||
|
||
GPU训练:
|
||
- batch=32-64
|
||
- imgsz=640
|
||
- 训练时间: ~10-20分钟/100轮
|
||
|
||
使用示例:
|
||
# 基础训练(100轮)
|
||
python scripts/train_yolo.py
|
||
|
||
# 长时间训练(200轮)
|
||
python scripts/train_yolo.py --epochs 200 --name exp_200
|
||
|
||
# 使用更大模型
|
||
python scripts/train_yolo.py --model yolov8s.pt --epochs 150
|
||
|
||
# 高清训练
|
||
python scripts/train_yolo.py --imgsz 640 --batch 8 --name exp_hd
|
||
|
||
# 自定义输出目录
|
||
python scripts/train_yolo.py \
|
||
--project my_runs \
|
||
--name my_experiment \
|
||
--epochs 150
|
||
|
||
监控训练:
|
||
# 实时查看训练指标
|
||
tail -f runs/digit_yolo/<name>/results.csv
|
||
|
||
# TensorBoard可视化(可选)
|
||
tensorboard --logdir runs/digit_yolo
|
||
|
||
依赖环境:
|
||
- ultralytics >= 8.0.0
|
||
- torch >= 2.0.0
|
||
- opencv-python
|
||
|
||
作者: Gavin Chan
|
||
版本: 1.0
|
||
日期: 2025-10-30
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
from ultralytics import YOLO
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
|
||
"""
|
||
解析命令行参数
|
||
|
||
Returns:
|
||
argparse.Namespace: 训练配置参数
|
||
- data: 数据集配置文件路径(dataset.yaml)
|
||
- model: 预训练模型名称或路径
|
||
- epochs: 训练轮数
|
||
- imgsz: 输入图片大小
|
||
- batch: 批次大小
|
||
- project: 输出项目目录
|
||
- name: 实验名称
|
||
- valid_dir: 额外验证图片目录
|
||
"""
|
||
parser = argparse.ArgumentParser(description="Train YOLO model for digit recognition")
|
||
parser.add_argument("--data", type=Path, default=Path("yolo_dataset/dataset.yaml"), help="path to dataset yaml")
|
||
parser.add_argument("--model", type=str, default="yolov8n.pt", help="pretrained YOLO checkpoint")
|
||
parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
|
||
parser.add_argument("--imgsz", type=int, default=320, help="image size")
|
||
parser.add_argument("--batch", type=int, default=16, help="batch size")
|
||
parser.add_argument("--project", type=str, default="runs/digit_yolo", help="training output directory")
|
||
parser.add_argument("--name", type=str, default="exp", help="run name")
|
||
parser.add_argument(
|
||
"--valid-dir", type=Path, default=Path("valid"), help="directory with four-digit images for evaluation"
|
||
)
|
||
return parser.parse_args()
|
||
|
||
|
||
def main() -> None:
|
||
"""
|
||
主函数:执行YOLO模型训练流程
|
||
|
||
完整流程:
|
||
1. 解析命令行参数
|
||
2. 加载YOLO预训练模型
|
||
3. 开始训练(自动保存检查点)
|
||
4. 训练完成后加载最佳模型
|
||
5. 在验证集上评估性能
|
||
6. 可选:在valid文件夹上进行推理
|
||
|
||
训练输出:
|
||
- 每个epoch的训练和验证指标
|
||
- 混淆矩阵
|
||
- PR曲线
|
||
- 训练曲线图
|
||
- 最佳和最后模型权重
|
||
|
||
验证指标:
|
||
- mAP50: IoU=0.5时的mAP(主要指标)
|
||
- mAP50-95: IoU从0.5到0.95的平均mAP
|
||
- Precision: 精确率
|
||
- Recall: 召回率
|
||
- 每个类别(数字0-9)的性能
|
||
|
||
异常处理:
|
||
- FileNotFoundError: 数据集配置文件不存在
|
||
- RuntimeError: 训练失败或模型加载失败
|
||
"""
|
||
args = parse_args()
|
||
model = YOLO(args.model)
|
||
|
||
results = model.train(
|
||
data=str(args.data),
|
||
epochs=args.epochs,
|
||
imgsz=args.imgsz,
|
||
batch=args.batch,
|
||
project=args.project,
|
||
name=args.name,
|
||
exist_ok=True,
|
||
)
|
||
|
||
print("Training complete. Summary metrics:")
|
||
print(results)
|
||
|
||
best_ckpt = Path(results.save_dir) / "weights" / "best.pt"
|
||
if not best_ckpt.exists():
|
||
raise FileNotFoundError(f"Best checkpoint not found at {best_ckpt}")
|
||
|
||
# Validate on the validation split
|
||
model = YOLO(str(best_ckpt))
|
||
print("Running validation...")
|
||
val_metrics = model.val(data=str(args.data), imgsz=args.imgsz, project=args.project, name=f"{args.name}_val")
|
||
print(val_metrics)
|
||
|
||
# Inference on the valid folder
|
||
if args.valid_dir.exists():
|
||
print(f"Running inference on {args.valid_dir} ...")
|
||
model.predict(
|
||
source=str(args.valid_dir),
|
||
imgsz=args.imgsz,
|
||
save=True,
|
||
save_txt=True,
|
||
project=args.project,
|
||
name=f"{args.name}_valid",
|
||
)
|
||
print(f"Predictions saved to {Path(args.project) / f'{args.name}_valid'}")
|
||
else:
|
||
print(f"Valid directory {args.valid_dir} not found; skipping inference.")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|