first commit
This commit is contained in:
435
scripts/compare_results.py
Normal file
435
scripts/compare_results.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
YOLO数字识别结果对比工具
|
||||
|
||||
功能说明:
|
||||
对比两个模型或不同配置的识别结果,生成详细的对比报告。
|
||||
主要用于评估预处理、模型优化等改进措施的效果。
|
||||
|
||||
主要功能:
|
||||
- 加载并解析两个识别结果文件
|
||||
- 统计整体准确率、置信度等指标
|
||||
- 逐张图片对比识别结果
|
||||
- 标识改进案例和退化案例
|
||||
- 生成Markdown格式的详细报告
|
||||
|
||||
对比维度:
|
||||
1. 整体统计:
|
||||
- 识别准确率(识别出4位数字的比例)
|
||||
- 平均置信度
|
||||
- 改进幅度
|
||||
|
||||
2. 详细对比:
|
||||
- 每张图片的识别结果对比
|
||||
- 置信度对比
|
||||
- 状态标识(改进/退化/保持/未改善)
|
||||
|
||||
3. 改进分析:
|
||||
- 新增正确识别的图片列表
|
||||
- 识别退化的图片列表
|
||||
- 改进建议
|
||||
|
||||
报告格式:
|
||||
生成的Markdown报告包含:
|
||||
- 📊 整体统计表格
|
||||
- 📝 详细对比表格
|
||||
- 🎯 改进案例列表
|
||||
- ⚠️ 退化案例列表
|
||||
- 📌 结论和建议
|
||||
|
||||
使用场景:
|
||||
场景1: 对比预处理效果
|
||||
python scripts/compare_results.py \
|
||||
--original results/predictions_original.txt \
|
||||
--preprocessed results/predictions_preprocessed.txt \
|
||||
--output results/preprocessing_comparison.md
|
||||
|
||||
场景2: 对比不同模型
|
||||
python scripts/compare_results.py \
|
||||
--original results/predictions_exp1.txt \
|
||||
--preprocessed results/predictions_exp2.txt \
|
||||
--output results/model_comparison.md
|
||||
|
||||
场景3: 对比不同置信度阈值
|
||||
python scripts/compare_results.py \
|
||||
--original results/predictions_conf02.txt \
|
||||
--preprocessed results/predictions_conf01.txt \
|
||||
--output results/threshold_comparison.md
|
||||
|
||||
输入格式:
|
||||
识别结果文件应为制表符分隔的文本文件:
|
||||
```
|
||||
文件名 识别结果 置信度 数字个数
|
||||
YZM.jpeg 3809 0.584 4
|
||||
YZM-2.jpeg 87 0.358 2
|
||||
```
|
||||
|
||||
输出示例:
|
||||
```markdown
|
||||
# 预处理效果对比报告
|
||||
|
||||
## 📊 整体统计
|
||||
| 指标 | 原始模型 | 预处理模型 | 改进 |
|
||||
|------|----------|------------|------|
|
||||
| 识别准确率 | 20.0% (3/15) | 80.0% (12/15) | +60.0% |
|
||||
| 平均置信度 | 0.512 | 0.653 | +0.141 |
|
||||
|
||||
## 🎯 改进案例
|
||||
预处理后新增识别正确的图片(9张):
|
||||
- **YZM-11.jpeg**: 53 (2位) → 5389 (4位) ✅
|
||||
...
|
||||
```
|
||||
|
||||
依赖环境:
|
||||
- Python 3.8+
|
||||
- 无第三方依赖(仅使用标准库)
|
||||
|
||||
注意事项:
|
||||
- 两个结果文件应该是在相同图片集上的识别结果
|
||||
- 文件名必须对应才能正确对比
|
||||
- 结果文件格式必须正确(制表符分隔)
|
||||
|
||||
作者: Gavin Chan
|
||||
版本: 1.0
|
||||
日期: 2025-10-30
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
解析命令行参数
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: 参数对象
|
||||
- original: 原始模型的识别结果文件路径
|
||||
- preprocessed: 优化后模型的识别结果文件路径
|
||||
- output: 对比报告输出文件路径(Markdown格式)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="对比预处理前后的识别效果")
|
||||
parser.add_argument(
|
||||
"--original",
|
||||
type=Path,
|
||||
default=Path("results/predictions_improved.txt"),
|
||||
help="原始模型的识别结果文件"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessed",
|
||||
type=Path,
|
||||
default=Path("results/predictions_exp_preprocessed_150.txt"),
|
||||
help="预处理后模型的识别结果文件"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("results/comparison_report.md"),
|
||||
help="对比报告输出文件"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_results(file_path: Path) -> Dict[str, Dict[str, any]]:
|
||||
"""
|
||||
加载并解析识别结果文件
|
||||
|
||||
文件格式:
|
||||
制表符分隔的文本文件,格式如下:
|
||||
文件名 识别结果 置信度 数字个数
|
||||
YZM.jpeg 3809 0.584 4
|
||||
...
|
||||
|
||||
处理流程:
|
||||
1. 检查文件是否存在
|
||||
2. 读取所有行
|
||||
3. 跳过标题行(第一行)
|
||||
4. 解析每一行的数据
|
||||
5. 将结果存储为字典
|
||||
|
||||
Args:
|
||||
file_path (Path): 识别结果文件路径
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict]: 识别结果字典
|
||||
键: 文件名(如 "YZM.jpeg")
|
||||
值: 字典包含
|
||||
- digits: 识别出的数字字符串
|
||||
- confidence: 平均置信度(float)
|
||||
- digit_count: 识别出的数字个数(int)
|
||||
- correct: 是否正确识别4位(bool)
|
||||
|
||||
异常处理:
|
||||
- 文件不存在: 打印警告并返回空字典
|
||||
- 格式错误: 跳过该行继续处理
|
||||
|
||||
示例:
|
||||
>>> results = load_results(Path("results/predictions.txt"))
|
||||
>>> print(results["YZM.jpeg"])
|
||||
{'digits': '3809', 'confidence': 0.584, 'digit_count': 4, 'correct': True}
|
||||
"""
|
||||
results = {}
|
||||
|
||||
if not file_path.exists():
|
||||
print(f"警告: 文件不存在 {file_path}")
|
||||
return results
|
||||
|
||||
with file_path.open('r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# 跳过标题行
|
||||
for line in lines[1:]:
|
||||
parts = line.strip().split('\t')
|
||||
if len(parts) >= 4:
|
||||
filename = parts[0]
|
||||
digits = parts[1]
|
||||
confidence = float(parts[2])
|
||||
digit_count = int(parts[3])
|
||||
|
||||
results[filename] = {
|
||||
'digits': digits,
|
||||
'confidence': confidence,
|
||||
'digit_count': digit_count,
|
||||
'correct': digit_count == 4
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def generate_comparison_report(
|
||||
original_results: Dict,
|
||||
preprocessed_results: Dict,
|
||||
output_path: Path
|
||||
) -> None:
|
||||
"""
|
||||
生成详细的Markdown格式对比报告
|
||||
|
||||
报告内容:
|
||||
1. 整体统计表格
|
||||
- 识别准确率对比
|
||||
- 平均置信度对比
|
||||
- 改进幅度
|
||||
|
||||
2. 详细对比表格
|
||||
- 每张图片的识别结果
|
||||
- 置信度变化
|
||||
- 状态标识(改进/退化/保持/未改善)
|
||||
|
||||
3. 改进案例
|
||||
- 列出从错误到正确的图片
|
||||
- 显示具体的改进效果
|
||||
|
||||
4. 退化案例
|
||||
- 列出从正确到错误的图片
|
||||
- 分析可能的原因
|
||||
|
||||
5. 结论和建议
|
||||
- 总结改进效果
|
||||
- 提供优化建议
|
||||
|
||||
状态判断逻辑:
|
||||
- ✅ 改进: 原来错误,现在正确(最重要)
|
||||
- ❌ 退化: 原来正确,现在错误(需要关注)
|
||||
- ✓ 保持: 两次都正确(稳定)
|
||||
- - 未改善: 两次都错误(仍需改进)
|
||||
|
||||
Args:
|
||||
original_results (Dict): 原始模型的识别结果
|
||||
格式: {文件名: {digits, confidence, digit_count, correct}}
|
||||
preprocessed_results (Dict): 优化后模型的识别结果
|
||||
格式同上
|
||||
output_path (Path): 报告输出文件路径(.md文件)
|
||||
|
||||
Returns:
|
||||
None: 报告直接写入文件
|
||||
|
||||
输出示例:
|
||||
生成的报告包含完整的统计、对比和分析信息,
|
||||
便于评估优化效果和发现问题。
|
||||
|
||||
注意:
|
||||
- 会覆盖已存在的输出文件
|
||||
- 确保有足够的磁盘空间
|
||||
- 文件使用UTF-8编码
|
||||
"""
|
||||
# 统计
|
||||
original_correct = sum(1 for r in original_results.values() if r['correct'])
|
||||
preprocessed_correct = sum(1 for r in preprocessed_results.values() if r['correct'])
|
||||
|
||||
total_images = len(original_results)
|
||||
|
||||
original_accuracy = (original_correct / total_images * 100) if total_images > 0 else 0
|
||||
preprocessed_accuracy = (preprocessed_correct / total_images * 100) if total_images > 0 else 0
|
||||
|
||||
improvement = preprocessed_accuracy - original_accuracy
|
||||
|
||||
# 生成报告
|
||||
with output_path.open('w', encoding='utf-8') as f:
|
||||
f.write("# 预处理效果对比报告\n\n")
|
||||
f.write("## 📊 整体统计\n\n")
|
||||
f.write(f"| 指标 | 原始模型 | 预处理模型 | 改进 |\n")
|
||||
f.write(f"|------|----------|------------|------|\n")
|
||||
f.write(f"| 识别准确率 | {original_accuracy:.1f}% ({original_correct}/{total_images}) | {preprocessed_accuracy:.1f}% ({preprocessed_correct}/{total_images}) | {improvement:+.1f}% |\n")
|
||||
|
||||
# 平均置信度
|
||||
original_avg_conf = sum(r['confidence'] for r in original_results.values()) / len(original_results) if original_results else 0
|
||||
preprocessed_avg_conf = sum(r['confidence'] for r in preprocessed_results.values()) / len(preprocessed_results) if preprocessed_results else 0
|
||||
|
||||
f.write(f"| 平均置信度 | {original_avg_conf:.3f} | {preprocessed_avg_conf:.3f} | {preprocessed_avg_conf - original_avg_conf:+.3f} |\n\n")
|
||||
|
||||
# 详细对比
|
||||
f.write("## 📝 详细对比\n\n")
|
||||
f.write("| 文件名 | 原始识别 | 置信度 | 预处理识别 | 置信度 | 状态 |\n")
|
||||
f.write("|--------|----------|--------|------------|--------|------|\n")
|
||||
|
||||
for filename in sorted(original_results.keys()):
|
||||
orig = original_results[filename]
|
||||
prep = preprocessed_results.get(filename, {'digits': 'N/A', 'confidence': 0.0, 'correct': False})
|
||||
|
||||
# 判断状态
|
||||
if not orig['correct'] and prep['correct']:
|
||||
status = "✅ 改进"
|
||||
elif orig['correct'] and not prep['correct']:
|
||||
status = "❌ 退化"
|
||||
elif orig['correct'] and prep['correct']:
|
||||
status = "✓ 保持"
|
||||
else:
|
||||
status = "- 未改善"
|
||||
|
||||
f.write(f"| {filename} | {orig['digits'] or '-'} | {orig['confidence']:.3f} | {prep['digits'] or '-'} | {prep['confidence']:.3f} | {status} |\n")
|
||||
|
||||
# 改进案例
|
||||
f.write("\n## 🎯 改进案例\n\n")
|
||||
improved = [fn for fn in original_results.keys()
|
||||
if not original_results[fn]['correct'] and preprocessed_results.get(fn, {}).get('correct', False)]
|
||||
|
||||
if improved:
|
||||
f.write(f"预处理后新增识别正确的图片({len(improved)}张):\n\n")
|
||||
for fn in improved:
|
||||
orig = original_results[fn]
|
||||
prep = preprocessed_results[fn]
|
||||
f.write(f"- **{fn}**: {orig['digits'] or '(无)'} ({orig['digit_count']}位) → {prep['digits']} (4位) ✅\n")
|
||||
else:
|
||||
f.write("暂无新增正确识别的图片\n")
|
||||
|
||||
# 退化案例
|
||||
f.write("\n## ⚠️ 退化案例\n\n")
|
||||
regressed = [fn for fn in original_results.keys()
|
||||
if original_results[fn]['correct'] and not preprocessed_results.get(fn, {}).get('correct', False)]
|
||||
|
||||
if regressed:
|
||||
f.write(f"预处理后识别错误的图片({len(regressed)}张):\n\n")
|
||||
for fn in regressed:
|
||||
orig = original_results[fn]
|
||||
prep = preprocessed_results[fn]
|
||||
f.write(f"- **{fn}**: {orig['digits']} (4位) → {prep['digits'] or '(无)'} ({prep['digit_count']}位) ❌\n")
|
||||
else:
|
||||
f.write("没有退化案例 ✓\n")
|
||||
|
||||
# 结论
|
||||
f.write("\n## 📌 结论\n\n")
|
||||
if improvement > 0:
|
||||
f.write(f"✅ **预处理有效**:准确率提升 {improvement:.1f}%\n\n")
|
||||
f.write("预处理(去噪+对比度增强+灰度化)对提升数字识别效果有积极作用。\n")
|
||||
elif improvement < 0:
|
||||
f.write(f"⚠️ **预处理效果不佳**:准确率下降 {abs(improvement):.1f}%\n\n")
|
||||
f.write("预处理可能过度处理了图片,建议:\n")
|
||||
f.write("- 尝试其他预处理方法(如 --method clahe 或 combined)\n")
|
||||
f.write("- 调整预处理参数\n")
|
||||
f.write("- 保持彩色图片(--keep-color)\n")
|
||||
else:
|
||||
f.write("预处理效果与原始模型相当。\n")
|
||||
|
||||
f.write("\n---\n")
|
||||
f.write("*报告生成时间: 2025-10-30*\n")
|
||||
|
||||
print(f"✓ 对比报告已生成: {output_path}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""
|
||||
主函数:执行结果对比流程
|
||||
|
||||
完整流程:
|
||||
1. 解析命令行参数
|
||||
2. 加载两个识别结果文件
|
||||
3. 验证数据有效性
|
||||
4. 生成详细对比报告(Markdown)
|
||||
5. 在控制台显示简要统计
|
||||
|
||||
输出内容:
|
||||
控制台输出:
|
||||
- 加载进度信息
|
||||
- 简要统计对比
|
||||
- 准确率变化
|
||||
- 报告文件路径
|
||||
|
||||
文件输出:
|
||||
- 完整的Markdown格式对比报告
|
||||
- 包含表格、列表、统计图表等
|
||||
|
||||
异常处理:
|
||||
- 文件不存在: 打印错误并退出
|
||||
- 数据为空: 打印错误并退出
|
||||
- 其他异常向上传播
|
||||
|
||||
使用示例:
|
||||
>>> # 命令行调用
|
||||
>>> python scripts/compare_results.py \
|
||||
... --original results/predictions_v1.txt \
|
||||
... --preprocessed results/predictions_v2.txt
|
||||
|
||||
输出:
|
||||
加载识别结果...
|
||||
原始结果: 15 张图片
|
||||
预处理结果: 15 张图片
|
||||
✓ 对比报告已生成: results/comparison_report.md
|
||||
|
||||
================================================================================
|
||||
预处理效果对比
|
||||
================================================================================
|
||||
原始模型: 3/15 (20.0%)
|
||||
预处理模型: 12/15 (80.0%)
|
||||
改进: +9 (+60.0%)
|
||||
================================================================================
|
||||
"""
|
||||
args = parse_args()
|
||||
|
||||
print("加载识别结果...")
|
||||
original_results = load_results(args.original)
|
||||
preprocessed_results = load_results(args.preprocessed)
|
||||
|
||||
if not original_results:
|
||||
print(f"错误: 无法加载原始结果 {args.original}")
|
||||
return
|
||||
|
||||
if not preprocessed_results:
|
||||
print(f"错误: 无法加载预处理结果 {args.preprocessed}")
|
||||
return
|
||||
|
||||
print(f"原始结果: {len(original_results)} 张图片")
|
||||
print(f"预处理结果: {len(preprocessed_results)} 张图片")
|
||||
|
||||
# 生成报告
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
generate_comparison_report(original_results, preprocessed_results, args.output)
|
||||
|
||||
# 显示简要统计
|
||||
print("\n" + "=" * 80)
|
||||
print("预处理效果对比")
|
||||
print("=" * 80)
|
||||
|
||||
original_correct = sum(1 for r in original_results.values() if r['correct'])
|
||||
preprocessed_correct = sum(1 for r in preprocessed_results.values() if r['correct'])
|
||||
total = len(original_results)
|
||||
|
||||
print(f"原始模型: {original_correct}/{total} ({original_correct/total*100:.1f}%)")
|
||||
print(f"预处理模型: {preprocessed_correct}/{total} ({preprocessed_correct/total*100:.1f}%)")
|
||||
print(f"改进: {preprocessed_correct - original_correct:+d} ({(preprocessed_correct - original_correct)/total*100:+.1f}%)")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user