first commit

This commit is contained in:
douboer
2025-10-30 15:40:56 +08:00
parent fe4a3e7cbf
commit 2fb4b22328
344 changed files with 8595 additions and 567 deletions

View File

@@ -0,0 +1,224 @@
"""
完整的预处理+训练流程
步骤:
1. 预处理digit-validation图片
2. 预处理valid图片
3. 使用预处理后的数据准备YOLO数据集
4. 训练新模型
5. 在预处理后的valid上测试
Usage:
python scripts/train_with_preprocessing.py --epochs 150 --method auto
"""
from __future__ import annotations
import argparse
import shutil
import subprocess
import sys
from pathlib import Path
def run_command(cmd: list[str], description: str, cwd: Path = None) -> None:
"""运行命令并显示进度"""
print("\n" + "=" * 80)
print(f"{description}")
print("=" * 80)
print(f"命令: {' '.join(cmd)}")
print("-" * 80)
result = subprocess.run(cmd, cwd=cwd or Path.cwd())
if result.returncode != 0:
print(f"\n❌ 错误: {description} 失败")
sys.exit(1)
else:
print(f"\n{description} 成功完成")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="预处理+训练完整流程")
parser.add_argument(
"--preprocess-method",
type=str,
default="auto",
choices=["auto", "clahe", "binary", "denoise", "sharpen", "combined"],
help="预处理方法(默认: auto"
)
parser.add_argument(
"--epochs",
type=int,
default=150,
help="训练轮数默认150"
)
parser.add_argument(
"--batch",
type=int,
default=16,
help="批次大小默认16"
)
parser.add_argument(
"--model",
type=str,
default="yolov8n.pt",
help="预训练模型默认yolov8n.pt"
)
parser.add_argument(
"--exp-name",
type=str,
default="exp_preprocessed",
help="实验名称"
)
parser.add_argument(
"--skip-preprocess",
action="store_true",
help="跳过预处理步骤(如果已经预处理过)"
)
parser.add_argument(
"--preprocess-only",
action="store_true",
help="只做预处理,不训练模型"
)
parser.add_argument(
"--keep-color",
action="store_true",
help="保持彩色图片(默认转灰度)"
)
return parser.parse_args()
def main() -> None:
args = parse_args()
project_root = Path.cwd()
print("🚀 开始预处理+训练完整流程")
print("=" * 80)
print(f"预处理方法: {args.preprocess_method}")
print(f"训练轮数: {args.epochs}")
print(f"模型: {args.model}")
print(f"实验名称: {args.exp_name}")
print("=" * 80)
# 定义路径
digit_validation_input = project_root / "digit-validation" / "images"
digit_validation_output = project_root / "digit-validation-processed" / "images"
valid_input = project_root / "valid"
valid_output = project_root / "valid-processed"
# 步骤1: 预处理训练数据
if not args.skip_preprocess:
# 预处理digit-validation
run_command(
[
"python", "scripts/preprocess_images.py",
"--input", str(digit_validation_input),
"--output", str(digit_validation_output),
"--method", args.preprocess_method
] + (["--keep-color"] if args.keep_color else []),
"步骤1.1: 预处理训练数据集digit-validation",
cwd=project_root
)
# 复制coco.json到预处理后的目录
processed_root = project_root / "digit-validation-processed"
processed_root.mkdir(parents=True, exist_ok=True)
coco_src = project_root / "digit-validation" / "coco.json"
coco_dst = processed_root / "coco.json"
if coco_src.exists():
shutil.copy2(coco_src, coco_dst)
print(f"✓ 复制 coco.json 到 {coco_dst}")
# 预处理valid数据
run_command(
[
"python", "scripts/preprocess_images.py",
"--input", str(valid_input),
"--output", str(valid_output),
"--method", args.preprocess_method
] + (["--keep-color"] if args.keep_color else []),
"步骤1.2: 预处理验证数据集valid",
cwd=project_root
)
else:
print("\n⏭️ 跳过预处理步骤(使用已有的预处理数据)")
# 步骤2: 准备YOLO数据集使用预处理后的图片
yolo_dataset_output = project_root / "yolo_dataset_preprocessed"
run_command(
[
"python", "scripts/prepare_yolo_dataset.py",
"--root", "digit-validation-processed",
"--out", str(yolo_dataset_output),
"--val-ratio", "0.2",
"--seed", "20240305"
],
"步骤2: 准备YOLO数据集基于预处理后的图片",
cwd=project_root
)
# 步骤3: 训练模型
run_command(
[
"python", "scripts/train_yolo.py",
"--data", str(yolo_dataset_output / "dataset.yaml"),
"--model", args.model,
"--epochs", str(args.epochs),
"--batch", str(args.batch),
"--project", "runs/digit_yolo",
"--name", args.exp_name
],
"步骤3: 训练YOLO模型使用预处理数据",
cwd=project_root
)
# 步骤4: 在预处理后的valid数据上测试
best_model = project_root / "runs" / "digit_yolo" / args.exp_name / "weights" / "best.pt"
run_command(
[
"python", "scripts/predict_digits_improved.py",
"--model", str(best_model),
"--source", str(valid_output),
"--conf", "0.2",
"--output", f"results/predictions_{args.exp_name}.txt",
"--save-vis"
],
"步骤4: 在预处理后的valid数据上测试",
cwd=project_root
)
# 也在原始valid数据上测试做对比
run_command(
[
"python", "scripts/predict_digits_improved.py",
"--model", str(best_model),
"--source", str(valid_input),
"--conf", "0.2",
"--output", f"results/predictions_{args.exp_name}_original.txt",
],
"步骤5: 在原始valid数据上测试对比",
cwd=project_root
)
print("\n" + "=" * 80)
print("🎉 完整流程完成!")
print("=" * 80)
print("\n📊 查看结果:")
print(f" - 预处理后的训练数据: {digit_validation_output}")
print(f" - 预处理后的验证数据: {valid_output}")
print(f" - YOLO数据集: {yolo_dataset_output}")
print(f" - 训练模型: {best_model}")
print(f" - 识别结果(预处理数据): results/predictions_{args.exp_name}.txt")
print(f" - 识别结果(原始数据): results/predictions_{args.exp_name}_original.txt")
print(f" - 可视化结果: results/visualizations/")
print()
if __name__ == "__main__":
main()