""" YOLO数字识别模型训练脚本 功能说明: 使用YOLOv8在准备好的数字数据集上训练目标检测模型。 支持从预训练模型开始进行迁移学习,加速训练过程。 主要功能: - 加载YOLO预训练模型(yolov8n.pt等) - 在数字数据集上进行训练 - 自动保存最佳模型和最后模型 - 训练完成后自动验证 - 可选:在valid文件夹上进行推理测试 训练流程: 1. 加载预训练模型(ImageNet或COCO预训练) 2. 在数字数据集上微调 3. 每个epoch保存检查点 4. 根据验证集mAP保存最佳模型 5. 训练完成后加载最佳模型进行验证 输出文件: runs/digit_yolo// ├── weights/ │ ├── best.pt # 最佳模型(验证集mAP最高) │ └── last.pt # 最后一个epoch的模型 ├── results.csv # 训练指标(loss, mAP等) ├── results.png # 训练曲线图 ├── confusion_matrix.png # 混淆矩阵 └── args.yaml # 训练参数记录 训练参数说明: - epochs: 训练轮数(100-200推荐) - batch: 批次大小(根据显存调整,CPU建议8-16) - imgsz: 输入图片大小(320快速,640精确) - model: 预训练模型(yolov8n最轻量,yolov8s/m更准确) 性能优化建议: CPU训练: - batch=8-16 - imgsz=320 - workers=4 - 训练时间: ~2-3小时/100轮 GPU训练: - batch=32-64 - imgsz=640 - 训练时间: ~10-20分钟/100轮 使用示例: # 基础训练(100轮) python scripts/train_yolo.py # 长时间训练(200轮) python scripts/train_yolo.py --epochs 200 --name exp_200 # 使用更大模型 python scripts/train_yolo.py --model yolov8s.pt --epochs 150 # 高清训练 python scripts/train_yolo.py --imgsz 640 --batch 8 --name exp_hd # 自定义输出目录 python scripts/train_yolo.py \ --project my_runs \ --name my_experiment \ --epochs 150 监控训练: # 实时查看训练指标 tail -f runs/digit_yolo//results.csv # TensorBoard可视化(可选) tensorboard --logdir runs/digit_yolo 依赖环境: - ultralytics >= 8.0.0 - torch >= 2.0.0 - opencv-python 作者: Gavin Chan 版本: 1.0 日期: 2025-10-30 """ from __future__ import annotations import argparse from pathlib import Path from ultralytics import YOLO def parse_args() -> argparse.Namespace: """ 解析命令行参数 Returns: argparse.Namespace: 训练配置参数 - data: 数据集配置文件路径(dataset.yaml) - model: 预训练模型名称或路径 - epochs: 训练轮数 - imgsz: 输入图片大小 - batch: 批次大小 - project: 输出项目目录 - name: 实验名称 - valid_dir: 额外验证图片目录 """ parser = argparse.ArgumentParser(description="Train YOLO model for digit recognition") parser.add_argument("--data", type=Path, default=Path("yolo_dataset/dataset.yaml"), help="path to dataset yaml") parser.add_argument("--model", type=str, default="yolov8n.pt", help="pretrained YOLO checkpoint") parser.add_argument("--epochs", type=int, default=100, help="number of training epochs") parser.add_argument("--imgsz", type=int, default=320, help="image size") parser.add_argument("--batch", type=int, default=16, help="batch size") parser.add_argument("--project", type=str, default="runs/digit_yolo", help="training output directory") parser.add_argument("--name", type=str, default="exp", help="run name") parser.add_argument( "--valid-dir", type=Path, default=Path("valid"), help="directory with four-digit images for evaluation" ) return parser.parse_args() def main() -> None: """ 主函数:执行YOLO模型训练流程 完整流程: 1. 解析命令行参数 2. 加载YOLO预训练模型 3. 开始训练(自动保存检查点) 4. 训练完成后加载最佳模型 5. 在验证集上评估性能 6. 可选:在valid文件夹上进行推理 训练输出: - 每个epoch的训练和验证指标 - 混淆矩阵 - PR曲线 - 训练曲线图 - 最佳和最后模型权重 验证指标: - mAP50: IoU=0.5时的mAP(主要指标) - mAP50-95: IoU从0.5到0.95的平均mAP - Precision: 精确率 - Recall: 召回率 - 每个类别(数字0-9)的性能 异常处理: - FileNotFoundError: 数据集配置文件不存在 - RuntimeError: 训练失败或模型加载失败 """ args = parse_args() model = YOLO(args.model) results = model.train( data=str(args.data), epochs=args.epochs, imgsz=args.imgsz, batch=args.batch, project=args.project, name=args.name, exist_ok=True, ) print("Training complete. Summary metrics:") print(results) best_ckpt = Path(results.save_dir) / "weights" / "best.pt" if not best_ckpt.exists(): raise FileNotFoundError(f"Best checkpoint not found at {best_ckpt}") # Validate on the validation split model = YOLO(str(best_ckpt)) print("Running validation...") val_metrics = model.val(data=str(args.data), imgsz=args.imgsz, project=args.project, name=f"{args.name}_val") print(val_metrics) # Inference on the valid folder if args.valid_dir.exists(): print(f"Running inference on {args.valid_dir} ...") model.predict( source=str(args.valid_dir), imgsz=args.imgsz, save=True, save_txt=True, project=args.project, name=f"{args.name}_valid", ) print(f"Predictions saved to {Path(args.project) / f'{args.name}_valid'}") else: print(f"Valid directory {args.valid_dir} not found; skipping inference.") if __name__ == "__main__": main()