first commit

This commit is contained in:
douboer
2025-10-30 15:40:56 +08:00
parent fe4a3e7cbf
commit 2fb4b22328
344 changed files with 8595 additions and 567 deletions

197
scripts/train_yolo.py Normal file
View File

@@ -0,0 +1,197 @@
"""
YOLO数字识别模型训练脚本
功能说明:
使用YOLOv8在准备好的数字数据集上训练目标检测模型。
支持从预训练模型开始进行迁移学习,加速训练过程。
主要功能:
- 加载YOLO预训练模型yolov8n.pt等
- 在数字数据集上进行训练
- 自动保存最佳模型和最后模型
- 训练完成后自动验证
- 可选在valid文件夹上进行推理测试
训练流程:
1. 加载预训练模型ImageNet或COCO预训练
2. 在数字数据集上微调
3. 每个epoch保存检查点
4. 根据验证集mAP保存最佳模型
5. 训练完成后加载最佳模型进行验证
输出文件:
runs/digit_yolo/<name>/
├── weights/
│ ├── best.pt # 最佳模型验证集mAP最高
│ └── last.pt # 最后一个epoch的模型
├── results.csv # 训练指标loss, mAP等
├── results.png # 训练曲线图
├── confusion_matrix.png # 混淆矩阵
└── args.yaml # 训练参数记录
训练参数说明:
- epochs: 训练轮数100-200推荐
- batch: 批次大小根据显存调整CPU建议8-16
- imgsz: 输入图片大小320快速640精确
- model: 预训练模型yolov8n最轻量yolov8s/m更准确
性能优化建议:
CPU训练:
- batch=8-16
- imgsz=320
- workers=4
- 训练时间: ~2-3小时/100轮
GPU训练:
- batch=32-64
- imgsz=640
- 训练时间: ~10-20分钟/100轮
使用示例:
# 基础训练100轮
python scripts/train_yolo.py
# 长时间训练200轮
python scripts/train_yolo.py --epochs 200 --name exp_200
# 使用更大模型
python scripts/train_yolo.py --model yolov8s.pt --epochs 150
# 高清训练
python scripts/train_yolo.py --imgsz 640 --batch 8 --name exp_hd
# 自定义输出目录
python scripts/train_yolo.py \
--project my_runs \
--name my_experiment \
--epochs 150
监控训练:
# 实时查看训练指标
tail -f runs/digit_yolo/<name>/results.csv
# TensorBoard可视化可选
tensorboard --logdir runs/digit_yolo
依赖环境:
- ultralytics >= 8.0.0
- torch >= 2.0.0
- opencv-python
作者: Gavin Chan
版本: 1.0
日期: 2025-10-30
"""
from __future__ import annotations
import argparse
from pathlib import Path
from ultralytics import YOLO
def parse_args() -> argparse.Namespace:
"""
解析命令行参数
Returns:
argparse.Namespace: 训练配置参数
- data: 数据集配置文件路径dataset.yaml
- model: 预训练模型名称或路径
- epochs: 训练轮数
- imgsz: 输入图片大小
- batch: 批次大小
- project: 输出项目目录
- name: 实验名称
- valid_dir: 额外验证图片目录
"""
parser = argparse.ArgumentParser(description="Train YOLO model for digit recognition")
parser.add_argument("--data", type=Path, default=Path("yolo_dataset/dataset.yaml"), help="path to dataset yaml")
parser.add_argument("--model", type=str, default="yolov8n.pt", help="pretrained YOLO checkpoint")
parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
parser.add_argument("--imgsz", type=int, default=320, help="image size")
parser.add_argument("--batch", type=int, default=16, help="batch size")
parser.add_argument("--project", type=str, default="runs/digit_yolo", help="training output directory")
parser.add_argument("--name", type=str, default="exp", help="run name")
parser.add_argument(
"--valid-dir", type=Path, default=Path("valid"), help="directory with four-digit images for evaluation"
)
return parser.parse_args()
def main() -> None:
"""
主函数执行YOLO模型训练流程
完整流程:
1. 解析命令行参数
2. 加载YOLO预训练模型
3. 开始训练(自动保存检查点)
4. 训练完成后加载最佳模型
5. 在验证集上评估性能
6. 可选在valid文件夹上进行推理
训练输出:
- 每个epoch的训练和验证指标
- 混淆矩阵
- PR曲线
- 训练曲线图
- 最佳和最后模型权重
验证指标:
- mAP50: IoU=0.5时的mAP主要指标
- mAP50-95: IoU从0.5到0.95的平均mAP
- Precision: 精确率
- Recall: 召回率
- 每个类别数字0-9的性能
异常处理:
- FileNotFoundError: 数据集配置文件不存在
- RuntimeError: 训练失败或模型加载失败
"""
args = parse_args()
model = YOLO(args.model)
results = model.train(
data=str(args.data),
epochs=args.epochs,
imgsz=args.imgsz,
batch=args.batch,
project=args.project,
name=args.name,
exist_ok=True,
)
print("Training complete. Summary metrics:")
print(results)
best_ckpt = Path(results.save_dir) / "weights" / "best.pt"
if not best_ckpt.exists():
raise FileNotFoundError(f"Best checkpoint not found at {best_ckpt}")
# Validate on the validation split
model = YOLO(str(best_ckpt))
print("Running validation...")
val_metrics = model.val(data=str(args.data), imgsz=args.imgsz, project=args.project, name=f"{args.name}_val")
print(val_metrics)
# Inference on the valid folder
if args.valid_dir.exists():
print(f"Running inference on {args.valid_dir} ...")
model.predict(
source=str(args.valid_dir),
imgsz=args.imgsz,
save=True,
save_txt=True,
project=args.project,
name=f"{args.name}_valid",
)
print(f"Predictions saved to {Path(args.project) / f'{args.name}_valid'}")
else:
print(f"Valid directory {args.valid_dir} not found; skipping inference.")
if __name__ == "__main__":
main()