Files
digit-cracker/scripts/run_all.py
2025-10-30 15:40:56 +08:00

120 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
完整的YOLO数字识别流程
包括:数据准备、模型训练、模型验证和推理
Usage:
python scripts/run_all.py [--skip-train] [--skip-predict]
"""
from __future__ import annotations
import argparse
import subprocess
import sys
from pathlib import Path
def run_command(cmd: list[str], description: str) -> None:
"""运行命令并显示进度"""
print("\n" + "=" * 80)
print(f"{description}")
print("=" * 80)
print(f"命令: {' '.join(cmd)}")
print("-" * 80)
result = subprocess.run(cmd, cwd=Path(__file__).parent.parent)
if result.returncode != 0:
print(f"\n❌ 错误: {description} 失败")
sys.exit(1)
else:
print(f"\n{description} 成功完成")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="运行完整的YOLO数字识别流程")
parser.add_argument(
"--skip-prepare",
action="store_true",
help="跳过数据准备步骤(如果已经准备好数据)"
)
parser.add_argument(
"--skip-train",
action="store_true",
help="跳过训练步骤(如果模型已训练)"
)
parser.add_argument(
"--skip-predict",
action="store_true",
help="跳过预测步骤"
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="训练轮数默认100"
)
parser.add_argument(
"--batch",
type=int,
default=16,
help="批次大小默认16"
)
return parser.parse_args()
def main() -> None:
args = parse_args()
print("🚀 开始YOLO数字识别完整流程")
print("=" * 80)
# 步骤1准备数据集
if not args.skip_prepare:
run_command(
["python", "scripts/prepare_yolo_dataset.py"],
"步骤1: 准备YOLO数据集"
)
else:
print("\n⏭️ 跳过数据准备步骤")
# 步骤2训练模型
if not args.skip_train:
run_command(
[
"python", "scripts/train_yolo.py",
"--epochs", str(args.epochs),
"--batch", str(args.batch),
"--name", "exp1"
],
"步骤2: 训练YOLO模型"
)
else:
print("\n⏭️ 跳过训练步骤")
# 步骤3在valid文件夹上进行预测
if not args.skip_predict:
run_command(
[
"python", "scripts/predict_digits.py",
"--save-vis"
],
"步骤3: 识别valid文件夹中的4位数字"
)
else:
print("\n⏭️ 跳过预测步骤")
print("\n" + "=" * 80)
print("🎉 所有步骤完成!")
print("=" * 80)
print("\n📊 查看结果:")
print(" - 训练结果: runs/digit_yolo/exp1/")
print(" - 预测结果: results/predictions.txt")
print(" - 可视化结果: results/visualizations/")
print()
if __name__ == "__main__":
main()