first commit
This commit is contained in:
197
scripts/train_yolo.py
Normal file
197
scripts/train_yolo.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
YOLO数字识别模型训练脚本
|
||||
|
||||
功能说明:
|
||||
使用YOLOv8在准备好的数字数据集上训练目标检测模型。
|
||||
支持从预训练模型开始进行迁移学习,加速训练过程。
|
||||
|
||||
主要功能:
|
||||
- 加载YOLO预训练模型(yolov8n.pt等)
|
||||
- 在数字数据集上进行训练
|
||||
- 自动保存最佳模型和最后模型
|
||||
- 训练完成后自动验证
|
||||
- 可选:在valid文件夹上进行推理测试
|
||||
|
||||
训练流程:
|
||||
1. 加载预训练模型(ImageNet或COCO预训练)
|
||||
2. 在数字数据集上微调
|
||||
3. 每个epoch保存检查点
|
||||
4. 根据验证集mAP保存最佳模型
|
||||
5. 训练完成后加载最佳模型进行验证
|
||||
|
||||
输出文件:
|
||||
runs/digit_yolo/<name>/
|
||||
├── weights/
|
||||
│ ├── best.pt # 最佳模型(验证集mAP最高)
|
||||
│ └── last.pt # 最后一个epoch的模型
|
||||
├── results.csv # 训练指标(loss, mAP等)
|
||||
├── results.png # 训练曲线图
|
||||
├── confusion_matrix.png # 混淆矩阵
|
||||
└── args.yaml # 训练参数记录
|
||||
|
||||
训练参数说明:
|
||||
- epochs: 训练轮数(100-200推荐)
|
||||
- batch: 批次大小(根据显存调整,CPU建议8-16)
|
||||
- imgsz: 输入图片大小(320快速,640精确)
|
||||
- model: 预训练模型(yolov8n最轻量,yolov8s/m更准确)
|
||||
|
||||
性能优化建议:
|
||||
CPU训练:
|
||||
- batch=8-16
|
||||
- imgsz=320
|
||||
- workers=4
|
||||
- 训练时间: ~2-3小时/100轮
|
||||
|
||||
GPU训练:
|
||||
- batch=32-64
|
||||
- imgsz=640
|
||||
- 训练时间: ~10-20分钟/100轮
|
||||
|
||||
使用示例:
|
||||
# 基础训练(100轮)
|
||||
python scripts/train_yolo.py
|
||||
|
||||
# 长时间训练(200轮)
|
||||
python scripts/train_yolo.py --epochs 200 --name exp_200
|
||||
|
||||
# 使用更大模型
|
||||
python scripts/train_yolo.py --model yolov8s.pt --epochs 150
|
||||
|
||||
# 高清训练
|
||||
python scripts/train_yolo.py --imgsz 640 --batch 8 --name exp_hd
|
||||
|
||||
# 自定义输出目录
|
||||
python scripts/train_yolo.py \
|
||||
--project my_runs \
|
||||
--name my_experiment \
|
||||
--epochs 150
|
||||
|
||||
监控训练:
|
||||
# 实时查看训练指标
|
||||
tail -f runs/digit_yolo/<name>/results.csv
|
||||
|
||||
# TensorBoard可视化(可选)
|
||||
tensorboard --logdir runs/digit_yolo
|
||||
|
||||
依赖环境:
|
||||
- ultralytics >= 8.0.0
|
||||
- torch >= 2.0.0
|
||||
- opencv-python
|
||||
|
||||
作者: Gavin Chan
|
||||
版本: 1.0
|
||||
日期: 2025-10-30
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
解析命令行参数
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: 训练配置参数
|
||||
- data: 数据集配置文件路径(dataset.yaml)
|
||||
- model: 预训练模型名称或路径
|
||||
- epochs: 训练轮数
|
||||
- imgsz: 输入图片大小
|
||||
- batch: 批次大小
|
||||
- project: 输出项目目录
|
||||
- name: 实验名称
|
||||
- valid_dir: 额外验证图片目录
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Train YOLO model for digit recognition")
|
||||
parser.add_argument("--data", type=Path, default=Path("yolo_dataset/dataset.yaml"), help="path to dataset yaml")
|
||||
parser.add_argument("--model", type=str, default="yolov8n.pt", help="pretrained YOLO checkpoint")
|
||||
parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
|
||||
parser.add_argument("--imgsz", type=int, default=320, help="image size")
|
||||
parser.add_argument("--batch", type=int, default=16, help="batch size")
|
||||
parser.add_argument("--project", type=str, default="runs/digit_yolo", help="training output directory")
|
||||
parser.add_argument("--name", type=str, default="exp", help="run name")
|
||||
parser.add_argument(
|
||||
"--valid-dir", type=Path, default=Path("valid"), help="directory with four-digit images for evaluation"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
主函数:执行YOLO模型训练流程
|
||||
|
||||
完整流程:
|
||||
1. 解析命令行参数
|
||||
2. 加载YOLO预训练模型
|
||||
3. 开始训练(自动保存检查点)
|
||||
4. 训练完成后加载最佳模型
|
||||
5. 在验证集上评估性能
|
||||
6. 可选:在valid文件夹上进行推理
|
||||
|
||||
训练输出:
|
||||
- 每个epoch的训练和验证指标
|
||||
- 混淆矩阵
|
||||
- PR曲线
|
||||
- 训练曲线图
|
||||
- 最佳和最后模型权重
|
||||
|
||||
验证指标:
|
||||
- mAP50: IoU=0.5时的mAP(主要指标)
|
||||
- mAP50-95: IoU从0.5到0.95的平均mAP
|
||||
- Precision: 精确率
|
||||
- Recall: 召回率
|
||||
- 每个类别(数字0-9)的性能
|
||||
|
||||
异常处理:
|
||||
- FileNotFoundError: 数据集配置文件不存在
|
||||
- RuntimeError: 训练失败或模型加载失败
|
||||
"""
|
||||
args = parse_args()
|
||||
model = YOLO(args.model)
|
||||
|
||||
results = model.train(
|
||||
data=str(args.data),
|
||||
epochs=args.epochs,
|
||||
imgsz=args.imgsz,
|
||||
batch=args.batch,
|
||||
project=args.project,
|
||||
name=args.name,
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
print("Training complete. Summary metrics:")
|
||||
print(results)
|
||||
|
||||
best_ckpt = Path(results.save_dir) / "weights" / "best.pt"
|
||||
if not best_ckpt.exists():
|
||||
raise FileNotFoundError(f"Best checkpoint not found at {best_ckpt}")
|
||||
|
||||
# Validate on the validation split
|
||||
model = YOLO(str(best_ckpt))
|
||||
print("Running validation...")
|
||||
val_metrics = model.val(data=str(args.data), imgsz=args.imgsz, project=args.project, name=f"{args.name}_val")
|
||||
print(val_metrics)
|
||||
|
||||
# Inference on the valid folder
|
||||
if args.valid_dir.exists():
|
||||
print(f"Running inference on {args.valid_dir} ...")
|
||||
model.predict(
|
||||
source=str(args.valid_dir),
|
||||
imgsz=args.imgsz,
|
||||
save=True,
|
||||
save_txt=True,
|
||||
project=args.project,
|
||||
name=f"{args.name}_valid",
|
||||
)
|
||||
print(f"Predictions saved to {Path(args.project) / f'{args.name}_valid'}")
|
||||
else:
|
||||
print(f"Valid directory {args.valid_dir} not found; skipping inference.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user