first commit
This commit is contained in:
119
scripts/run_all.py
Normal file
119
scripts/run_all.py
Normal file
@@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整的YOLO数字识别流程
|
||||
包括:数据准备、模型训练、模型验证和推理
|
||||
|
||||
Usage:
|
||||
python scripts/run_all.py [--skip-train] [--skip-predict]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_command(cmd: list[str], description: str) -> None:
|
||||
"""运行命令并显示进度"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"⚡ {description}")
|
||||
print("=" * 80)
|
||||
print(f"命令: {' '.join(cmd)}")
|
||||
print("-" * 80)
|
||||
|
||||
result = subprocess.run(cmd, cwd=Path(__file__).parent.parent)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"\n❌ 错误: {description} 失败")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"\n✅ {description} 成功完成")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="运行完整的YOLO数字识别流程")
|
||||
parser.add_argument(
|
||||
"--skip-prepare",
|
||||
action="store_true",
|
||||
help="跳过数据准备步骤(如果已经准备好数据)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-train",
|
||||
action="store_true",
|
||||
help="跳过训练步骤(如果模型已训练)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-predict",
|
||||
action="store_true",
|
||||
help="跳过预测步骤"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
type=int,
|
||||
default=100,
|
||||
help="训练轮数(默认100)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
type=int,
|
||||
default=16,
|
||||
help="批次大小(默认16)"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
print("🚀 开始YOLO数字识别完整流程")
|
||||
print("=" * 80)
|
||||
|
||||
# 步骤1:准备数据集
|
||||
if not args.skip_prepare:
|
||||
run_command(
|
||||
["python", "scripts/prepare_yolo_dataset.py"],
|
||||
"步骤1: 准备YOLO数据集"
|
||||
)
|
||||
else:
|
||||
print("\n⏭️ 跳过数据准备步骤")
|
||||
|
||||
# 步骤2:训练模型
|
||||
if not args.skip_train:
|
||||
run_command(
|
||||
[
|
||||
"python", "scripts/train_yolo.py",
|
||||
"--epochs", str(args.epochs),
|
||||
"--batch", str(args.batch),
|
||||
"--name", "exp1"
|
||||
],
|
||||
"步骤2: 训练YOLO模型"
|
||||
)
|
||||
else:
|
||||
print("\n⏭️ 跳过训练步骤")
|
||||
|
||||
# 步骤3:在valid文件夹上进行预测
|
||||
if not args.skip_predict:
|
||||
run_command(
|
||||
[
|
||||
"python", "scripts/predict_digits.py",
|
||||
"--save-vis"
|
||||
],
|
||||
"步骤3: 识别valid文件夹中的4位数字"
|
||||
)
|
||||
else:
|
||||
print("\n⏭️ 跳过预测步骤")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("🎉 所有步骤完成!")
|
||||
print("=" * 80)
|
||||
print("\n📊 查看结果:")
|
||||
print(" - 训练结果: runs/digit_yolo/exp1/")
|
||||
print(" - 预测结果: results/predictions.txt")
|
||||
print(" - 可视化结果: results/visualizations/")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user