Files
digit-cracker/scripts/train_with_preprocessing.py
2025-10-30 15:40:56 +08:00

225 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
完整的预处理+训练流程
步骤:
1. 预处理digit-validation图片
2. 预处理valid图片
3. 使用预处理后的数据准备YOLO数据集
4. 训练新模型
5. 在预处理后的valid上测试
Usage:
python scripts/train_with_preprocessing.py --epochs 150 --preprocess-method auto
"""
from __future__ import annotations
import argparse
import shutil
import subprocess
import sys
from pathlib import Path
def run_command(cmd: list[str], description: str, cwd: Path | None = None) -> None:
    """Run an external command, printing a banner before and a status line after.

    Args:
        cmd: Program and arguments, handed to ``subprocess.run`` as a list
            (no shell is involved).
        description: Human-readable step name shown in the banner and in the
            success/failure message.
        cwd: Working directory for the child process. ``None`` means the
            current working directory of this process.

    Raises:
        SystemExit: With status 1 when the command exits with a non-zero
            return code, or when it cannot be started at all (for example
            the executable or ``cwd`` does not exist).
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f"{description}")
    print(rule)
    print(f"命令: {' '.join(cmd)}")
    # Flush before spawning: the child writes straight to the inherited file
    # descriptor, so without this the banner can show up after the child's
    # output whenever stdout is block-buffered (pipe or file redirect).
    print("-" * 80, flush=True)

    try:
        result = subprocess.run(cmd, cwd=cwd or Path.cwd())
    except OSError as exc:
        # Launch failures are reported the same way as a failing command
        # instead of escaping as an unhandled traceback.
        print(f"\n❌ 错误: {description} 失败 ({exc})")
        sys.exit(1)

    if result.returncode != 0:
        print(f"\n❌ 错误: {description} 失败")
        sys.exit(1)

    print(f"\n{description} 成功完成")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the preprocessing + training pipeline.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) lets ``argparse`` read ``sys.argv[1:]``, which is
            what existing callers get.

    Returns:
        Namespace with ``preprocess_method``, ``epochs``, ``batch``,
        ``model``, ``exp_name``, ``skip_preprocess``, ``preprocess_only``
        and ``keep_color``.
    """
    parser = argparse.ArgumentParser(description="预处理+训练完整流程")

    # Options that take a value.
    parser.add_argument(
        "--preprocess-method",
        type=str,
        default="auto",
        choices=["auto", "clahe", "binary", "denoise", "sharpen", "combined"],
        help="预处理方法(默认: auto)",
    )
    parser.add_argument("--epochs", type=int, default=150, help="训练轮数默认150")
    parser.add_argument("--batch", type=int, default=16, help="批次大小默认16")
    parser.add_argument(
        "--model", type=str, default="yolov8n.pt", help="预训练模型默认yolov8n.pt"
    )
    parser.add_argument(
        "--exp-name", type=str, default="exp_preprocessed", help="实验名称"
    )

    # Boolean switches; each is False unless given on the command line.
    parser.add_argument(
        "--skip-preprocess",
        action="store_true",
        help="跳过预处理步骤(如果已经预处理过)",
    )
    parser.add_argument(
        "--preprocess-only",
        action="store_true",
        help="只做预处理,不训练模型",
    )
    parser.add_argument(
        "--keep-color",
        action="store_true",
        help="保持彩色图片(默认转灰度)",
    )

    return parser.parse_args(argv)
def _preprocess_command(
    python: str, input_dir: Path, output_dir: Path, method: str, keep_color: bool
) -> list[str]:
    """Build the argument list for one ``scripts/preprocess_images.py`` run."""
    command = [
        python, "scripts/preprocess_images.py",
        "--input", str(input_dir),
        "--output", str(output_dir),
        "--method", method,
    ]
    if keep_color:
        command.append("--keep-color")
    return command


def main() -> None:
    """Drive the whole pipeline: preprocess, build dataset, train, predict.

    All paths are resolved against the current working directory, which is
    also used as the working directory of every child script. The steps, each
    launched through ``run_command`` (which exits the process on failure):

    1. ``scripts/preprocess_images.py`` on ``digit-validation/images`` and on
       ``valid`` (skipped with ``--skip-preprocess``); ``coco.json`` is
       copied next to the processed training images.
    2. ``scripts/prepare_yolo_dataset.py`` on the processed training data.
    3. ``scripts/train_yolo.py`` on the resulting ``dataset.yaml``.
    4. ``scripts/predict_digits_improved.py`` on the processed ``valid``
       images and, for comparison, on the original ``valid`` images.

    With ``--preprocess-only`` the function returns after step 1.
    """
    args = parse_args()
    project_root = Path.cwd()
    # Launch child scripts with the interpreter that is running this one, so
    # they see the same environment; a bare "python" may resolve to another
    # installation or not exist at all.
    python = sys.executable or "python"

    rule = "=" * 80
    print("🚀 开始预处理+训练完整流程")
    print(rule)
    print(f"预处理方法: {args.preprocess_method}")
    print(f"训练轮数: {args.epochs}")
    print(f"模型: {args.model}")
    print(f"实验名称: {args.exp_name}")
    print(rule)

    # Input/output locations.
    processed_root = project_root / "digit-validation-processed"
    digit_validation_input = project_root / "digit-validation" / "images"
    digit_validation_output = processed_root / "images"
    valid_input = project_root / "valid"
    valid_output = project_root / "valid-processed"

    # Step 1: preprocess the training and validation images.
    if args.skip_preprocess:
        print("\n⏭️ 跳过预处理步骤(使用已有的预处理数据)")
    else:
        run_command(
            _preprocess_command(
                python,
                digit_validation_input,
                digit_validation_output,
                args.preprocess_method,
                args.keep_color,
            ),
            "步骤1.1: 预处理训练数据集digit-validation",
            cwd=project_root,
        )

        # Copy the COCO annotation file next to the processed images.
        processed_root.mkdir(parents=True, exist_ok=True)
        coco_src = project_root / "digit-validation" / "coco.json"
        coco_dst = processed_root / "coco.json"
        if coco_src.exists():
            shutil.copy2(coco_src, coco_dst)
            print(f"✓ 复制 coco.json 到 {coco_dst}")
        else:
            # Report the missing file instead of skipping the copy silently.
            print(f"⚠️ 未找到 {coco_src},跳过复制 coco.json")

        run_command(
            _preprocess_command(
                python,
                valid_input,
                valid_output,
                args.preprocess_method,
                args.keep_color,
            ),
            "步骤1.2: 预处理验证数据集valid",
            cwd=project_root,
        )

    # --preprocess-only: stop before dataset preparation, training and testing.
    if args.preprocess_only:
        print("\n" + rule)
        print("✅ 预处理完成(--preprocess-only: 跳过数据集准备、训练和测试)")
        print(rule)
        print(f" - 预处理后的训练数据: {digit_validation_output}")
        print(f" - 预处理后的验证数据: {valid_output}")
        return

    # Step 2: build the YOLO dataset from the preprocessed images.
    yolo_dataset_output = project_root / "yolo_dataset_preprocessed"
    run_command(
        [
            python, "scripts/prepare_yolo_dataset.py",
            "--root", "digit-validation-processed",
            "--out", str(yolo_dataset_output),
            "--val-ratio", "0.2",
            "--seed", "20240305",
        ],
        "步骤2: 准备YOLO数据集基于预处理后的图片",
        cwd=project_root,
    )

    # Step 3: train the model.
    run_command(
        [
            python, "scripts/train_yolo.py",
            "--data", str(yolo_dataset_output / "dataset.yaml"),
            "--model", args.model,
            "--epochs", str(args.epochs),
            "--batch", str(args.batch),
            "--project", "runs/digit_yolo",
            "--name", args.exp_name,
        ],
        "步骤3: 训练YOLO模型使用预处理数据",
        cwd=project_root,
    )

    # Step 4: predict on the preprocessed validation images.
    best_model = (
        project_root / "runs" / "digit_yolo" / args.exp_name / "weights" / "best.pt"
    )
    run_command(
        [
            python, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_output),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}.txt",
            "--save-vis",
        ],
        "步骤4: 在预处理后的valid数据上测试",
        cwd=project_root,
    )

    # Step 5: predict on the original validation images for comparison.
    run_command(
        [
            python, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_input),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}_original.txt",
        ],
        "步骤5: 在原始valid数据上测试对比",
        cwd=project_root,
    )

    # Final summary of where each artifact was written.
    print("\n" + rule)
    print("🎉 完整流程完成!")
    print(rule)
    print("\n📊 查看结果:")
    print(f" - 预处理后的训练数据: {digit_validation_output}")
    print(f" - 预处理后的验证数据: {valid_output}")
    print(f" - YOLO数据集: {yolo_dataset_output}")
    print(f" - 训练模型: {best_model}")
    print(f" - 识别结果(预处理数据): results/predictions_{args.exp_name}.txt")
    print(f" - 识别结果(原始数据): results/predictions_{args.exp_name}_original.txt")
    print(" - 可视化结果: results/visualizations/")
    print()
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()