120 lines
3.0 KiB
Python
120 lines
3.0 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
完整的YOLO数字识别流程
|
||
包括:数据准备、模型训练、模型验证和推理
|
||
|
||
Usage:
|
||
python scripts/run_all.py [--skip-train] [--skip-predict]
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import subprocess
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
|
||
def run_command(cmd: list[str], description: str) -> None:
|
||
"""运行命令并显示进度"""
|
||
print("\n" + "=" * 80)
|
||
print(f"⚡ {description}")
|
||
print("=" * 80)
|
||
print(f"命令: {' '.join(cmd)}")
|
||
print("-" * 80)
|
||
|
||
result = subprocess.run(cmd, cwd=Path(__file__).parent.parent)
|
||
|
||
if result.returncode != 0:
|
||
print(f"\n❌ 错误: {description} 失败")
|
||
sys.exit(1)
|
||
else:
|
||
print(f"\n✅ {description} 成功完成")
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
|
||
parser = argparse.ArgumentParser(description="运行完整的YOLO数字识别流程")
|
||
parser.add_argument(
|
||
"--skip-prepare",
|
||
action="store_true",
|
||
help="跳过数据准备步骤(如果已经准备好数据)"
|
||
)
|
||
parser.add_argument(
|
||
"--skip-train",
|
||
action="store_true",
|
||
help="跳过训练步骤(如果模型已训练)"
|
||
)
|
||
parser.add_argument(
|
||
"--skip-predict",
|
||
action="store_true",
|
||
help="跳过预测步骤"
|
||
)
|
||
parser.add_argument(
|
||
"--epochs",
|
||
type=int,
|
||
default=100,
|
||
help="训练轮数(默认100)"
|
||
)
|
||
parser.add_argument(
|
||
"--batch",
|
||
type=int,
|
||
default=16,
|
||
help="批次大小(默认16)"
|
||
)
|
||
return parser.parse_args()
|
||
|
||
|
||
def main() -> None:
|
||
args = parse_args()
|
||
|
||
print("🚀 开始YOLO数字识别完整流程")
|
||
print("=" * 80)
|
||
|
||
# 步骤1:准备数据集
|
||
if not args.skip_prepare:
|
||
run_command(
|
||
["python", "scripts/prepare_yolo_dataset.py"],
|
||
"步骤1: 准备YOLO数据集"
|
||
)
|
||
else:
|
||
print("\n⏭️ 跳过数据准备步骤")
|
||
|
||
# 步骤2:训练模型
|
||
if not args.skip_train:
|
||
run_command(
|
||
[
|
||
"python", "scripts/train_yolo.py",
|
||
"--epochs", str(args.epochs),
|
||
"--batch", str(args.batch),
|
||
"--name", "exp1"
|
||
],
|
||
"步骤2: 训练YOLO模型"
|
||
)
|
||
else:
|
||
print("\n⏭️ 跳过训练步骤")
|
||
|
||
# 步骤3:在valid文件夹上进行预测
|
||
if not args.skip_predict:
|
||
run_command(
|
||
[
|
||
"python", "scripts/predict_digits.py",
|
||
"--save-vis"
|
||
],
|
||
"步骤3: 识别valid文件夹中的4位数字"
|
||
)
|
||
else:
|
||
print("\n⏭️ 跳过预测步骤")
|
||
|
||
print("\n" + "=" * 80)
|
||
print("🎉 所有步骤完成!")
|
||
print("=" * 80)
|
||
print("\n📊 查看结果:")
|
||
print(" - 训练结果: runs/digit_yolo/exp1/")
|
||
print(" - 预测结果: results/predictions.txt")
|
||
print(" - 可视化结果: results/visualizations/")
|
||
print()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|