106 lines
3.4 KiB
TypeScript
106 lines
3.4 KiB
TypeScript
import { BoundingBox, Rectangle } from './types';
|
||
import { calculateIoU } from './utils/geometry';
|
||
|
||
/**
 * Compares detected rectangles against ground-truth target rectangles and
 * reports how well they line up (precision / recall plus the concrete pairs).
 */
class SliderValidator {
  /**
   * Checks whether two boxes match, allowing some positional and size deviation.
   *
   * @param detected - Box produced by the detector.
   * @param target - Ground-truth box.
   * @param tolerance - Maximum Euclidean distance (px) between the two box centres.
   * @param sizeTolerance - Maximum absolute width difference and height
   *   difference (px). Defaults to 30 because morphological operations can
   *   change the size of a detected box noticeably.
   * @returns `true` when the centre distance is `<= tolerance` and both size
   *   differences are `<= sizeTolerance`.
   */
  isBoxMatching(
    detected: Rectangle,
    target: Rectangle,
    tolerance: number = 10,
    sizeTolerance: number = 30
  ): boolean {
    const detectedCenterX = detected.x + detected.width / 2;
    const detectedCenterY = detected.y + detected.height / 2;
    const targetCenterX = target.x + target.width / 2;
    const targetCenterY = target.y + target.height / 2;

    const centerDistance = Math.hypot(
      detectedCenterX - targetCenterX,
      detectedCenterY - targetCenterY
    );

    const widthDiff = Math.abs(detected.width - target.width);
    const heightDiff = Math.abs(detected.height - target.height);

    return (
      centerDistance <= tolerance &&
      widthDiff <= sizeTolerance &&
      heightDiff <= sizeTolerance
    );
  }

  /**
   * Computes the IoU (intersection over union) of two boxes.
   *
   * Thin wrapper around the `calculateIoU` helper imported from
   * './utils/geometry'. Inside a method body the bare identifier resolves to
   * that module-level import, not to this method, so this does not recurse.
   */
  calculateIoU(box1: Rectangle, box2: Rectangle): number {
    return calculateIoU(box1, box2);
  }

  /**
   * Validates detection results against ground-truth targets.
   *
   * A (detected, target) pair is a candidate when it passes
   * `isBoxMatching(detected, target, tolerance)` and its IoU is strictly
   * greater than `minIoU`. Candidates are sorted by descending IoU and
   * accepted greedily, so every detected box and every target is used at most
   * once.
   *
   * Declared `async` although it does no asynchronous work; the Promise
   * return type is part of the public contract.
   *
   * @param detectedBoxes - Boxes produced by the detector (not mutated).
   * @param targetBoxes - Ground-truth boxes (not mutated).
   * @param tolerance - Maximum centre distance (px) passed to `isBoxMatching`.
   * @param minIoU - Candidate pairs with IoU at or below this value are
   *   discarded. Defaults to 0.1.
   * @returns Counts; `precision` = matched / detected (0 when there are no
   *   detections); `recall` = matched / targets (0 when there are no targets);
   *   the accepted matches; and the detected boxes left unmatched, in input
   *   order.
   */
  async validateDetection(
    detectedBoxes: readonly Rectangle[],
    targetBoxes: readonly Rectangle[],
    tolerance: number = 10,
    minIoU: number = 0.1
  ): Promise<{
    totalTargets: number;
    detectedCount: number;
    matchedCount: number;
    precision: number;
    recall: number;
    matches: Array<{ detected: Rectangle; target: Rectangle; iou: number }>;
    unmatched: Rectangle[];
  }> {
    // 1. Collect every pair that passes the geometric check and clears the
    //    IoU floor. The boxes are kept alongside their indices so the greedy
    //    step below never needs (possibly-undefined) indexed access.
    const candidates: Array<{
      detIdx: number;
      tarIdx: number;
      detected: Rectangle;
      target: Rectangle;
      iou: number;
    }> = [];
    for (const [detIdx, detected] of detectedBoxes.entries()) {
      for (const [tarIdx, target] of targetBoxes.entries()) {
        if (!this.isBoxMatching(detected, target, tolerance)) {
          continue;
        }
        const iou = this.calculateIoU(detected, target);
        if (iou > minIoU) {
          candidates.push({ detIdx, tarIdx, detected, target, iou });
        }
      }
    }

    // 2. Highest overlap first. Array.prototype.sort is stable (ES2019+), so
    //    ties keep their discovery order (detected-major, then target).
    candidates.sort((a, b) => b.iou - a.iou);

    // 3. Greedy one-to-one assignment.
    const matches: Array<{ detected: Rectangle; target: Rectangle; iou: number }> = [];
    const usedDetected = new Set<number>();
    const usedTargets = new Set<number>();
    for (const { detIdx, tarIdx, detected, target, iou } of candidates) {
      if (usedDetected.has(detIdx) || usedTargets.has(tarIdx)) {
        continue;
      }
      matches.push({ detected, target, iou });
      usedDetected.add(detIdx);
      usedTargets.add(tarIdx);
    }

    // Detected boxes that were never assigned to a target.
    const unmatched = detectedBoxes.filter((_, i) => !usedDetected.has(i));

    const matchedCount = matches.length;
    const precision = detectedBoxes.length > 0 ? matchedCount / detectedBoxes.length : 0;
    const recall = targetBoxes.length > 0 ? matchedCount / targetBoxes.length : 0;

    return {
      totalTargets: targetBoxes.length,
      detectedCount: detectedBoxes.length,
      matchedCount,
      precision,
      recall,
      matches,
      unmatched,
    };
  }
}
|
||
|
||
// BoundingBox and Rectangle come from './types' and are re-exported type-only,
// so per-file transpilers (isolatedModules: esbuild, Babel, SWC) do not emit a
// runtime re-export of a binding that './types' may not provide.
// NOTE(review): assumes both are interfaces/type aliases — confirm; if either
// is a runtime value (e.g. a class), move it back into the value export.
export { SliderValidator };
export type { BoundingBox, Rectangle };
|