first commit
This commit is contained in:
224
scripts/train_with_preprocessing.py
Normal file
224
scripts/train_with_preprocessing.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
完整的预处理+训练流程
|
||||
|
||||
步骤:
|
||||
1. 预处理digit-validation图片
|
||||
2. 预处理valid图片
|
||||
3. 使用预处理后的数据准备YOLO数据集
|
||||
4. 训练新模型
|
||||
5. 在预处理后的valid上测试
|
||||
|
||||
Usage:
|
||||
python scripts/train_with_preprocessing.py --epochs 150 --preprocess-method auto
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def run_command(cmd: list[str], description: str, cwd: Path | None = None) -> None:
    """Run an external command with a progress banner; abort the script on failure.

    Args:
        cmd: Program and arguments, handed to ``subprocess.run`` as a list
            (no shell is involved).
        description: Human-readable step name shown in the banner and in the
            success/failure message.
        cwd: Working directory for the child process. Defaults to the current
            working directory when omitted.

    Raises:
        SystemExit: With status 1 when the command cannot be started or
            finishes with a non-zero return code.
    """
    print("\n" + "=" * 80)
    print(f"⚡ {description}")
    print("=" * 80)
    print(f"命令: {' '.join(cmd)}")
    # Flush so the banner is written before the child's own output; otherwise
    # a block-buffered stdout (pipe/file) shows the banner after the child.
    print("-" * 80, flush=True)

    try:
        result = subprocess.run(cmd, cwd=cwd or Path.cwd())
    except OSError as exc:
        # The executable or working directory is missing/unusable: report it
        # the same way as a failing step instead of leaking a traceback.
        print(f"\n❌ 错误: {description} 失败 ({exc})")
        sys.exit(1)

    if result.returncode != 0:
        print(f"\n❌ 错误: {description} 失败")
        sys.exit(1)

    print(f"\n✅ {description} 成功完成")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the preprocess + train pipeline.

    Returns:
        Namespace with ``preprocess_method``, ``epochs``, ``batch``, ``model``,
        ``exp_name``, ``skip_preprocess``, ``preprocess_only`` and
        ``keep_color``.
    """
    parser = argparse.ArgumentParser(description="预处理+训练完整流程")
    parser.add_argument(
        "--preprocess-method",
        # Alias: the usage example in the module docstring passes ``--method``,
        # which was previously rejected as an unrecognized argument.
        "--method",
        dest="preprocess_method",
        type=str,
        default="auto",
        choices=["auto", "clahe", "binary", "denoise", "sharpen", "combined"],
        help="预处理方法(默认: auto)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=150,
        help="训练轮数(默认150)",
    )
    parser.add_argument(
        "--batch",
        type=int,
        default=16,
        help="批次大小(默认16)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="yolov8n.pt",
        help="预训练模型(默认yolov8n.pt)",
    )
    parser.add_argument(
        "--exp-name",
        type=str,
        default="exp_preprocessed",
        help="实验名称",
    )
    parser.add_argument(
        "--skip-preprocess",
        action="store_true",
        help="跳过预处理步骤(如果已经预处理过)",
    )
    parser.add_argument(
        "--preprocess-only",
        action="store_true",
        help="只做预处理,不训练模型",
    )
    parser.add_argument(
        "--keep-color",
        action="store_true",
        help="保持彩色图片(默认转灰度)",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def _build_preprocess_cmd(
    python_exe: str,
    source_dir: Path,
    target_dir: Path,
    method: str,
    keep_color: bool,
) -> list[str]:
    """Build the argv for one ``scripts/preprocess_images.py`` invocation."""
    cmd = [
        python_exe, "scripts/preprocess_images.py",
        "--input", str(source_dir),
        "--output", str(target_dir),
        "--method", method,
    ]
    if keep_color:
        cmd.append("--keep-color")
    return cmd


def main() -> None:
    """Run the preprocess -> dataset -> train -> predict pipeline.

    Paths are resolved against the current working directory and the helper
    scripts are addressed as ``scripts/...``, so launch this from the project
    root.

    Steps:
        1. Unless ``--skip-preprocess`` is given, preprocess
           ``digit-validation/images`` and ``valid`` and copy ``coco.json``
           next to the processed training images.
        2. Build the YOLO dataset from ``digit-validation-processed``.
        3. Train the model.
        4. Predict on the preprocessed ``valid`` images.
        5. Predict on the original ``valid`` images for comparison.

    With ``--preprocess-only`` the function returns after step 1. Every
    external step goes through ``run_command``, which exits the process when
    a step fails.
    """
    args = parse_args()

    project_root = Path.cwd()
    # Launch the child scripts with the interpreter running this script, so
    # they see the same virtualenv / installed packages; a bare "python" may
    # resolve to a different installation.
    python_exe = sys.executable or "python"

    print("🚀 开始预处理+训练完整流程")
    print("=" * 80)
    print(f"预处理方法: {args.preprocess_method}")
    print(f"训练轮数: {args.epochs}")
    print(f"模型: {args.model}")
    print(f"实验名称: {args.exp_name}")
    print("=" * 80)

    # Input / output locations.
    digit_validation_input = project_root / "digit-validation" / "images"
    digit_validation_output = project_root / "digit-validation-processed" / "images"
    valid_input = project_root / "valid"
    valid_output = project_root / "valid-processed"

    # Step 1: preprocess the training and validation images.
    if not args.skip_preprocess:
        run_command(
            _build_preprocess_cmd(
                python_exe,
                digit_validation_input,
                digit_validation_output,
                args.preprocess_method,
                args.keep_color,
            ),
            "步骤1.1: 预处理训练数据集(digit-validation)",
            cwd=project_root,
        )

        # Copy the annotation file alongside the processed images.
        processed_root = project_root / "digit-validation-processed"
        processed_root.mkdir(parents=True, exist_ok=True)

        coco_src = project_root / "digit-validation" / "coco.json"
        coco_dst = processed_root / "coco.json"

        if coco_src.exists():
            shutil.copy2(coco_src, coco_dst)
            print(f"✓ 复制 coco.json 到 {coco_dst}")
        else:
            # Make the missing annotation file visible instead of skipping
            # silently; later steps report their own error if they need it.
            print(f"⚠️ 未找到 {coco_src},跳过复制 coco.json")

        run_command(
            _build_preprocess_cmd(
                python_exe,
                valid_input,
                valid_output,
                args.preprocess_method,
                args.keep_color,
            ),
            "步骤1.2: 预处理验证数据集(valid)",
            cwd=project_root,
        )
    else:
        print("\n⏭️ 跳过预处理步骤(使用已有的预处理数据)")

    # Honour --preprocess-only: the flag was parsed but previously ignored,
    # so training always ran.
    if args.preprocess_only:
        print("\n✅ 预处理完成(--preprocess-only),跳过训练和测试")
        print(f" - 预处理后的训练数据: {digit_validation_output}")
        print(f" - 预处理后的验证数据: {valid_output}")
        return

    # Step 2: build the YOLO dataset from the preprocessed images.
    yolo_dataset_output = project_root / "yolo_dataset_preprocessed"

    run_command(
        [
            python_exe, "scripts/prepare_yolo_dataset.py",
            "--root", "digit-validation-processed",
            "--out", str(yolo_dataset_output),
            "--val-ratio", "0.2",
            "--seed", "20240305",
        ],
        "步骤2: 准备YOLO数据集(基于预处理后的图片)",
        cwd=project_root,
    )

    # Step 3: train the model.
    run_command(
        [
            python_exe, "scripts/train_yolo.py",
            "--data", str(yolo_dataset_output / "dataset.yaml"),
            "--model", args.model,
            "--epochs", str(args.epochs),
            "--batch", str(args.batch),
            "--project", "runs/digit_yolo",
            "--name", args.exp_name,
        ],
        "步骤3: 训练YOLO模型(使用预处理数据)",
        cwd=project_root,
    )

    # Step 4: predict on the preprocessed valid images.
    best_model = project_root / "runs" / "digit_yolo" / args.exp_name / "weights" / "best.pt"

    run_command(
        [
            python_exe, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_output),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}.txt",
            "--save-vis",
        ],
        "步骤4: 在预处理后的valid数据上测试",
        cwd=project_root,
    )

    # Step 5: predict on the original valid images as a comparison baseline.
    run_command(
        [
            python_exe, "scripts/predict_digits_improved.py",
            "--model", str(best_model),
            "--source", str(valid_input),
            "--conf", "0.2",
            "--output", f"results/predictions_{args.exp_name}_original.txt",
        ],
        "步骤5: 在原始valid数据上测试(对比)",
        cwd=project_root,
    )

    print("\n" + "=" * 80)
    print("🎉 完整流程完成!")
    print("=" * 80)
    print("\n📊 查看结果:")
    print(f" - 预处理后的训练数据: {digit_validation_output}")
    print(f" - 预处理后的验证数据: {valid_output}")
    print(f" - YOLO数据集: {yolo_dataset_output}")
    print(f" - 训练模型: {best_model}")
    print(f" - 识别结果(预处理数据): results/predictions_{args.exp_name}.txt")
    print(f" - 识别结果(原始数据): results/predictions_{args.exp_name}_original.txt")
    print(" - 可视化结果: results/visualizations/")
    print()
|
||||
|
||||
|
||||
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user