"""
完整的预处理+训练流程

步骤:
1. 预处理digit-validation图片
2. 预处理valid图片
3. 使用预处理后的数据准备YOLO数据集
4. 训练新模型
5. 在预处理后的valid上测试(并在原始valid上测试做对比)

Usage:
    python scripts/train_with_preprocessing.py --epochs 150 --preprocess-method auto
"""

from __future__ import annotations
|
||
|
||
import argparse
|
||
import shutil
|
||
import subprocess
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
|
||
def run_command(cmd: list[str], description: str, cwd: Path | None = None) -> None:
    """Run *cmd* as a child process, printing a banner before and a status after.

    Args:
        cmd: Argument vector passed to ``subprocess.run`` (no shell involved).
        description: Human-readable step name used in the banner and in the
            success/failure messages.
        cwd: Working directory for the child; defaults to the current directory.

    Terminates the whole program with ``sys.exit(1)`` when the command cannot
    be started or exits with a non-zero status, so later pipeline steps never
    run on top of a failed one.
    """
    separator = "=" * 80
    print("\n" + separator)
    print(f"⚡ {description}")
    print(separator)
    print(f"命令: {' '.join(cmd)}")
    print("-" * 80)
    # When stdout is redirected to a file/pipe it is block-buffered; flush so
    # the banner is written before the child's output instead of after it.
    sys.stdout.flush()

    try:
        result = subprocess.run(cmd, cwd=cwd or Path.cwd())
    except OSError as exc:
        # e.g. the executable is not on PATH or cwd does not exist: report it
        # like any other failed step rather than with a raw traceback.
        print(f"\n❌ 错误: {description} 失败 ({exc})")
        sys.exit(1)

    if result.returncode != 0:
        print(f"\n❌ 错误: {description} 失败")
        sys.exit(1)

    print(f"\n✅ {description} 成功完成")


def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that only accepts integers >= 1.

    Used for --epochs/--batch so that zero or negative values are rejected
    up front instead of being forwarded verbatim to the training script.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是有效的整数: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"必须是正整数: {text!r}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the pipeline driver.

    Returns:
        Namespace with ``preprocess_method``, ``epochs``, ``batch``, ``model``,
        ``exp_name``, ``skip_preprocess``, ``preprocess_only`` and
        ``keep_color``.

    ``--skip-preprocess`` and ``--preprocess-only`` are contradictory, so they
    are mutually exclusive. argparse exits with status 2 if both are given, or
    if --epochs/--batch is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="预处理+训练完整流程")
    parser.add_argument(
        "--preprocess-method",
        type=str,
        default="auto",
        choices=["auto", "clahe", "binary", "denoise", "sharpen", "combined"],
        help="预处理方法(默认: auto)",
    )
    parser.add_argument(
        "--epochs",
        type=_positive_int,
        default=150,
        help="训练轮数(默认150)",
    )
    parser.add_argument(
        "--batch",
        type=_positive_int,
        default=16,
        help="批次大小(默认16)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="yolov8n.pt",
        help="预训练模型(默认yolov8n.pt)",
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        default="exp_preprocessed",
        help="实验名称",
    )

    # Skipping preprocessing and doing *only* preprocessing cannot both hold.
    stage = parser.add_mutually_exclusive_group()
    stage.add_argument(
        "--skip-preprocess",
        action="store_true",
        help="跳过预处理步骤(如果已经预处理过)",
    )
    stage.add_argument(
        "--preprocess-only",
        action="store_true",
        help="只做预处理,不训练模型",
    )

    parser.add_argument(
        "--keep-color",
        action="store_true",
        help="保持彩色图片(默认转灰度)",
    )
    return parser.parse_args()


def _script_cmd(script: str, *script_args: str) -> list[str]:
    """Build an argv that runs ``scripts/<script>`` with the current interpreter."""
    # sys.executable rather than a bare "python": the child steps run in the
    # same interpreter/virtualenv as this driver even when "python" on PATH is
    # missing or points to a different installation.
    return [sys.executable, f"scripts/{script}", *script_args]


def _preprocess_images(
    src: Path,
    dst: Path,
    args: argparse.Namespace,
    description: str,
    project_root: Path,
) -> None:
    """Run scripts/preprocess_images.py on *src*, writing the results to *dst*."""
    cmd = _script_cmd(
        "preprocess_images.py",
        "--input", str(src),
        "--output", str(dst),
        "--method", args.preprocess_method,
    )
    if args.keep_color:
        cmd.append("--keep-color")
    run_command(cmd, description, cwd=project_root)


def _copy_coco_annotations(coco_src: Path, processed_root: Path) -> None:
    """Copy *coco_src* into *processed_root* as ``coco.json`` (if it exists)."""
    processed_root.mkdir(parents=True, exist_ok=True)
    coco_dst = processed_root / "coco.json"
    if coco_src.exists():
        shutil.copy2(coco_src, coco_dst)
        print(f"✓ 复制 coco.json 到 {coco_dst}")
    else:
        # Previously skipped silently. prepare_yolo_dataset.py presumably needs
        # these annotations, so surface the problem here (warning only).
        print(f"⚠️ 警告: 未找到 {coco_src}, 未复制标注文件")


def _predict(
    model: Path,
    source: Path,
    output: str,
    description: str,
    project_root: Path,
    *,
    save_vis: bool = False,
) -> None:
    """Run scripts/predict_digits_improved.py with *model* on *source* (conf 0.2)."""
    cmd = _script_cmd(
        "predict_digits_improved.py",
        "--model", str(model),
        "--source", str(source),
        "--conf", "0.2",
        "--output", output,
    )
    if save_vis:
        cmd.append("--save-vis")
    run_command(cmd, description, cwd=project_root)


def main() -> None:
    """Run the pipeline: preprocess -> build YOLO dataset -> train -> evaluate.

    All paths are resolved against the current working directory, and the
    helper scripts are invoked as ``scripts/...``. The driver must therefore
    be launched from the project root. Any failing step terminates the
    process via ``run_command``.
    """
    args = parse_args()
    project_root = Path.cwd()

    print("🚀 开始预处理+训练完整流程")
    print("=" * 80)
    print(f"预处理方法: {args.preprocess_method}")
    print(f"训练轮数: {args.epochs}")
    print(f"模型: {args.model}")
    print(f"实验名称: {args.exp_name}")
    print("=" * 80)

    # Input/output locations, all relative to the project root.
    digit_validation_root = project_root / "digit-validation"
    digit_validation_input = digit_validation_root / "images"
    processed_root = project_root / "digit-validation-processed"
    digit_validation_output = processed_root / "images"
    valid_input = project_root / "valid"
    valid_output = project_root / "valid-processed"
    yolo_dataset_output = project_root / "yolo_dataset_preprocessed"
    best_model = project_root / "runs" / "digit_yolo" / args.exp_name / "weights" / "best.pt"
    predictions_processed = f"results/predictions_{args.exp_name}.txt"
    predictions_original = f"results/predictions_{args.exp_name}_original.txt"

    # Step 1: preprocess the training (digit-validation) and valid images.
    if args.skip_preprocess:
        print("\n⏭️ 跳过预处理步骤(使用已有的预处理数据)")
        # Fail fast instead of letting dataset preparation (or prediction,
        # after a long training run) fail later on a missing directory.
        missing = [p for p in (processed_root, valid_output) if not p.exists()]
        if missing:
            print(f"\n❌ 错误: 找不到预处理数据: {', '.join(map(str, missing))}")
            sys.exit(1)
    else:
        _preprocess_images(
            digit_validation_input,
            digit_validation_output,
            args,
            "步骤1.1: 预处理训练数据集(digit-validation)",
            project_root,
        )
        _copy_coco_annotations(digit_validation_root / "coco.json", processed_root)
        _preprocess_images(
            valid_input,
            valid_output,
            args,
            "步骤1.2: 预处理验证数据集(valid)",
            project_root,
        )

    # --preprocess-only used to be parsed but ignored, so training always ran.
    if args.preprocess_only:
        print("\n✅ 预处理完成(--preprocess-only: 跳过训练和测试)")
        print(f" - 预处理后的训练数据: {digit_validation_output}")
        print(f" - 预处理后的验证数据: {valid_output}")
        return

    # Step 2: build the YOLO dataset from the preprocessed images.
    run_command(
        _script_cmd(
            "prepare_yolo_dataset.py",
            # Kept relative (resolved against cwd=project_root) exactly as before,
            # in case the script records this string in its output.
            "--root", "digit-validation-processed",
            "--out", str(yolo_dataset_output),
            "--val-ratio", "0.2",
            "--seed", "20240305",
        ),
        "步骤2: 准备YOLO数据集(基于预处理后的图片)",
        cwd=project_root,
    )

    # Step 3: train the model.
    run_command(
        _script_cmd(
            "train_yolo.py",
            "--data", str(yolo_dataset_output / "dataset.yaml"),
            "--model", args.model,
            "--epochs", str(args.epochs),
            "--batch", str(args.batch),
            "--project", "runs/digit_yolo",
            "--name", args.exp_name,
        ),
        "步骤3: 训练YOLO模型(使用预处理数据)",
        cwd=project_root,
    )

    if not best_model.exists():
        # NOTE(review): if train_yolo.py lets the trainer auto-increment the run
        # name when the directory already exists (e.g. exp_preprocessed2), the
        # weights land elsewhere. Confirm how train_yolo.py names its runs.
        print(f"\n❌ 错误: 未找到训练好的模型: {best_model}")
        sys.exit(1)

    # Step 4: evaluate on the preprocessed valid images (with visualizations).
    _predict(
        best_model,
        valid_output,
        predictions_processed,
        "步骤4: 在预处理后的valid数据上测试",
        project_root,
        save_vis=True,
    )

    # Step 5: also evaluate on the original valid images for comparison.
    _predict(
        best_model,
        valid_input,
        predictions_original,
        "步骤5: 在原始valid数据上测试(对比)",
        project_root,
    )

    print("\n" + "=" * 80)
    print("🎉 完整流程完成!")
    print("=" * 80)
    print("\n📊 查看结果:")
    print(f" - 预处理后的训练数据: {digit_validation_output}")
    print(f" - 预处理后的验证数据: {valid_output}")
    print(f" - YOLO数据集: {yolo_dataset_output}")
    print(f" - 训练模型: {best_model}")
    print(f" - 识别结果(预处理数据): {predictions_processed}")
    print(f" - 识别结果(原始数据): {predictions_original}")
    print(" - 可视化结果: results/visualizations/")
    print()


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 when
    # every step succeeds; failing steps exit earlier via run_command.
    raise SystemExit(main())