"""Run the complete preprocessing + training pipeline.

Steps:
    1. Preprocess the digit-validation images (training data).
    2. Preprocess the valid images (evaluation data).
    3. Build the YOLO dataset from the preprocessed data.
    4. Train a new model.
    5. Run prediction on the preprocessed valid images and, for comparison,
       on the original valid images.

All paths are resolved against the current working directory, so the script
must be started from the project root.

Usage:
    python scripts/train_with_preprocessing.py --epochs 150 --preprocess-method auto
"""
from __future__ import annotations

import argparse
import shutil
import subprocess
import sys
from pathlib import Path

# Launch helper scripts with the interpreter that is running this script so
# they share its environment; a bare "python" may be absent from PATH or
# resolve to a different installation.
PYTHON_EXECUTABLE = sys.executable or "python"


def run_command(cmd: list[str], description: str, cwd: Path | None = None) -> None:
    """Run *cmd* as a subprocess, printing a banner before and the outcome after.

    Args:
        cmd: Program and arguments, handed to ``subprocess.run`` as a list
            (no shell is involved).
        description: Step name shown in the banner and in the result message.
        cwd: Working directory for the child process; defaults to the current
            working directory.

    Raises:
        SystemExit: With status 1 when the command cannot be started or exits
            with a non-zero return code.
    """
    print("\n" + "=" * 80)
    print(f"⚡ {description}")
    print("=" * 80)
    print(f"命令: {' '.join(cmd)}")
    print("-" * 80)
    # The child writes straight to the inherited stdout; flush our buffered
    # banner first so the output stays in order when stdout is redirected.
    sys.stdout.flush()

    try:
        result = subprocess.run(cmd, cwd=cwd if cwd is not None else Path.cwd())
    except OSError as exc:
        # E.g. the executable does not exist: report it like any other failure
        # instead of dumping a traceback.
        print(f"\n❌ 错误: {description} 失败 ({exc})")
        sys.exit(1)

    if result.returncode != 0:
        print(f"\n❌ 错误: {description} 失败")
        sys.exit(1)

    print(f"\n✅ {description} 成功完成")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the pipeline.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="预处理+训练完整流程")
    parser.add_argument(
        "--preprocess-method",
        type=str,
        default="auto",
        choices=["auto", "clahe", "binary", "denoise", "sharpen", "combined"],
        help="预处理方法(默认: auto)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=150,
        help="训练轮数(默认150)",
    )
    parser.add_argument(
        "--batch",
        type=int,
        default=16,
        help="批次大小(默认16)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="yolov8n.pt",
        help="预训练模型(默认yolov8n.pt)",
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        default="exp_preprocessed",
        help="实验名称",
    )
    parser.add_argument(
        "--skip-preprocess",
        action="store_true",
        help="跳过预处理步骤(如果已经预处理过)",
    )
    parser.add_argument(
        "--preprocess-only",
        action="store_true",
        help="只做预处理,不训练模型",
    )
    parser.add_argument(
        "--keep-color",
        action="store_true",
        help="保持彩色图片(默认转灰度)",
    )
    return parser.parse_args(argv)


def _preprocess_command(args: argparse.Namespace, source: Path, target: Path) -> list[str]:
    """Build the preprocess_images.py invocation for one image directory."""
    cmd = [
        PYTHON_EXECUTABLE, "scripts/preprocess_images.py",
        "--input", str(source),
        "--output", str(target),
        "--method", args.preprocess_method,
    ]
    if args.keep_color:
        cmd.append("--keep-color")
    return cmd


def main(argv: list[str] | None = None) -> None:
    """Run preprocessing, dataset preparation, training and prediction in order.

    Every stage is a helper script started through ``run_command``, so the
    first failing stage terminates the process with exit status 1.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    project_root = Path.cwd()

    print("🚀 开始预处理+训练完整流程")
    print("=" * 80)
    print(f"预处理方法: {args.preprocess_method}")
    print(f"训练轮数: {args.epochs}")
    print(f"模型: {args.model}")
    print(f"实验名称: {args.exp_name}")
    print("=" * 80)

    # Locations of the raw and preprocessed data.
    digit_validation_input = project_root / "digit-validation" / "images"
    processed_root = project_root / "digit-validation-processed"
    digit_validation_output = processed_root / "images"
    valid_input = project_root / "valid"
    valid_output = project_root / "valid-processed"

    # Step 1: preprocess the training and evaluation images.
    if not args.skip_preprocess:
        run_command(
            _preprocess_command(args, digit_validation_input, digit_validation_output),
            "步骤1.1: 预处理训练数据集(digit-validation)",
            cwd=project_root,
        )

        # Copy coco.json next to the processed images so the processed
        # directory mirrors the layout of the original one.
        processed_root.mkdir(parents=True, exist_ok=True)
        coco_src = project_root / "digit-validation" / "coco.json"
        coco_dst = processed_root / "coco.json"
        if coco_src.exists():
            shutil.copy2(coco_src, coco_dst)
            print(f"✓ 复制 coco.json 到 {coco_dst}")
        else:
            print(f"⚠️ 未找到 {coco_src},未复制 coco.json")

        run_command(
            _preprocess_command(args, valid_input, valid_output),
            "步骤1.2: 预处理验证数据集(valid)",
            cwd=project_root,
        )
    else:
        print("\n⏭️ 跳过预处理步骤(使用已有的预处理数据)")

    # Honour --preprocess-only: stop before dataset preparation and training.
    if args.preprocess_only:
        print("\n" + "=" * 80)
        print("⏹️ 仅预处理模式: 跳过数据集准备、训练和测试")
        print("=" * 80)
        print(f" - 预处理后的训练数据: {digit_validation_output}")
        print(f" - 预处理后的验证数据: {valid_output}")
        return

    # Step 2: build the YOLO dataset from the preprocessed images.
    yolo_dataset_output = project_root / "yolo_dataset_preprocessed"
    run_command(
        [
            PYTHON_EXECUTABLE, "scripts/prepare_yolo_dataset.py",
            # Relative path; resolved by the child against cwd=project_root.
            "--root", "digit-validation-processed",
            "--out", str(yolo_dataset_output),
            "--val-ratio", "0.2",
            "--seed", "20240305",
        ],
        "步骤2: 准备YOLO数据集(基于预处理后的图片)",
        cwd=project_root,
    )

    # Step 3: train the model.
    run_command(
        [
            PYTHON_EXECUTABLE, "scripts/train_yolo.py",
            "--data", str(yolo_dataset_output / "dataset.yaml"),
            "--model", args.model,
            "--epochs", str(args.epochs),
            "--batch", str(args.batch),
            "--project", "runs/digit_yolo",
            "--name", args.exp_name,
        ],
        "步骤3: 训练YOLO模型(使用预处理数据)",
        cwd=project_root,
    )

    # Step 4: predict on the preprocessed valid images.
    best_model = project_root / "runs" / "digit_yolo" / args.exp_name / "weights" / "best.pt"
    run_command(
        [
            PYTHON_EXECUTABLE, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_output),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}.txt",
            "--save-vis",
        ],
        "步骤4: 在预处理后的valid数据上测试",
        cwd=project_root,
    )

    # Step 5: predict on the original valid images as a comparison baseline.
    run_command(
        [
            PYTHON_EXECUTABLE, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_input),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}_original.txt",
        ],
        "步骤5: 在原始valid数据上测试(对比)",
        cwd=project_root,
    )

    print("\n" + "=" * 80)
    print("🎉 完整流程完成!")
    print("=" * 80)
    print("\n📊 查看结果:")
    print(f" - 预处理后的训练数据: {digit_validation_output}")
    print(f" - 预处理后的验证数据: {valid_output}")
    print(f" - YOLO数据集: {yolo_dataset_output}")
    print(f" - 训练模型: {best_model}")
    print(f" - 识别结果(预处理数据): results/predictions_{args.exp_name}.txt")
    print(f" - 识别结果(原始数据): results/predictions_{args.exp_name}_original.txt")
    print(" - 可视化结果: results/visualizations/")
    print()


if __name__ == "__main__":
    main()