Files
digit-cracker/scripts/prepare_yolo_dataset.py
2025-10-30 15:40:56 +08:00

281 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
COCO到YOLO数据集转换工具
功能说明:
将COCO格式的数字标注数据集转换为YOLO训练所需的格式。
COCO格式使用JSON存储标注, YOLO格式使用文本文件存储。
主要功能:
- 解析COCO格式的标注文件(coco.json)
- 提取图片和边界框信息
- 转换边界框格式: COCO [x,y,w,h] → YOLO [x_center,y_center,w,h] (归一化)
- 自动划分训练集和验证集
- 创建YOLO标准目录结构
- 生成dataset.yaml配置文件
格式转换详解:
COCO格式:
- bbox: [x, y, width, height] (像素坐标,左上角)
- 绝对坐标,单位为像素
YOLO格式:
- bbox: [x_center, y_center, width, height] (归一化坐标)
- 相对坐标, 值在0-1之间
- x_center = (x + width/2) / img_width
- y_center = (y + height/2) / img_height
目录结构:
输入COCO格式:
digit-validation/
├── coco.json # 标注文件
└── images/ # 图片文件
输出YOLO格式:
yolo_dataset/
├── dataset.yaml # 配置文件
├── images/
│ ├── train/ # 训练集图片
│ └── val/ # 验证集图片
└── labels/
├── train/ # 训练集标注(.txt)
└── val/ # 验证集标注(.txt)
数据划分:
- 使用固定随机种子保证可重复性
- 默认20%作为验证集
- 保持图片和标注的对应关系
使用示例:
# 基础使用(默认参数)
python scripts/prepare_yolo_dataset.py
# 自定义参数
python scripts/prepare_yolo_dataset.py \
--root digit-validation \
--out yolo_dataset \
--val-ratio 0.2 \
--seed 42
# 只使用训练集(不划分验证集)
python scripts/prepare_yolo_dataset.py --val-ratio 0.0
注意事项:
- 确保coco.json中的file_name与实际图片文件匹配
- 类别ID必须是0-9数字
- 边界框坐标不能超出图片范围
- 输出目录中的同名文件会被覆盖(其余旧文件不会被清理), 注意备份
作者: Gavin Chan
版本: 1.0
日期: 2025-10-30
"""
from __future__ import annotations
import argparse
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
@dataclass
class CocoImage:
    """One entry of the ``images`` array of a COCO annotation file."""

    id: int  # identifier that annotations point at through ``image_id``
    file_name: str  # image file name, joined onto the images directory when read
    width: int  # image width; divisor for normalising horizontal coordinates
    height: int  # image height; divisor for normalising vertical coordinates
@dataclass
class CocoAnnotation:
    """One entry of the ``annotations`` array of a COCO annotation file."""

    id: int  # identifier of the annotation itself
    image_id: int  # ``id`` of the image this annotation belongs to
    bbox: List[float]  # x, y, width, height
    property_info: str  # label text of the annotation; expected to hold the digit
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the dataset preparation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        argparse.Namespace with the attributes
            - root: COCO dataset root directory (default ``digit-validation``)
            - out: YOLO dataset output directory (default ``yolo_dataset``)
            - val_ratio: fraction of the images used for validation, in [0, 1]
            - seed: random seed that keeps the train/val split reproducible

    Raises:
        SystemExit: raised by argparse for unparsable arguments, and when
            ``--val-ratio`` lies outside [0, 1].
    """
    parser = argparse.ArgumentParser(description="Prepare YOLO dataset from digit-validation COCO json")
    parser.add_argument("--root", type=Path, default=Path("digit-validation"), help="digit-validation directory")
    parser.add_argument("--out", type=Path, default=Path("yolo_dataset"), help="output dataset directory")
    parser.add_argument("--val-ratio", type=float, default=0.2, help="validation split ratio")
    parser.add_argument("--seed", type=int, default=20240305, help="random seed")
    args = parser.parse_args(argv)
    # A ratio outside [0, 1] (or NaN, which fails this comparison too) would
    # otherwise turn into a meaningless slice index when the data is split.
    if not 0.0 <= args.val_ratio <= 1.0:
        parser.error(f"--val-ratio must be within [0, 1], got {args.val_ratio}")
    return args
def load_coco(root: Path) -> tuple[List[CocoImage], List[CocoAnnotation]]:
    """Load ``coco.json`` from *root* and turn it into typed records.

    Args:
        root (Path): COCO dataset root directory; must contain ``coco.json``.

    Returns:
        tuple: pair ``(images, annotations)``
            - images: one CocoImage per entry of the JSON ``images`` array
            - annotations: one CocoAnnotation per entry of ``annotations``

    Raises:
        FileNotFoundError: if ``coco.json`` does not exist under *root*.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a required key (``images``, ``annotations``, or one of the
            per-entry fields read below) is missing.
    """
    coco_path = root / "coco.json"
    if not coco_path.exists():
        raise FileNotFoundError(f"COCO file not found at {coco_path}")
    with coco_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    images = [
        CocoImage(id=img["id"], file_name=img["file_name"], width=img["width"], height=img["height"])
        for img in data["images"]
    ]

    annotations: List[CocoAnnotation] = []
    for ann in data["annotations"]:
        # "property_info" may be absent, JSON null, or a bare number instead of
        # a string; normalise all of those to stripped text so that later
        # string handling of the label cannot fail with AttributeError.
        raw_label = ann.get("property_info")
        label = "" if raw_label is None else str(raw_label).strip()
        annotations.append(
            CocoAnnotation(
                id=ann["id"],
                image_id=ann["image_id"],
                bbox=ann["bbox"],
                property_info=label,
            )
        )
    return images, annotations
def ensure_dirs(out_root: Path) -> Dict[str, Path]:
    """Create the images/labels x train/val directory tree under *out_root*.

    Directories that already exist are left untouched; missing parents are
    created as needed.

    Args:
        out_root (Path): root directory of the YOLO dataset.

    Returns:
        Dict[str, Path]: the four directories keyed as
            - images_train: ``out_root/images/train``
            - images_val: ``out_root/images/val``
            - labels_train: ``out_root/labels/train``
            - labels_val: ``out_root/labels/val``
    """
    layout: Dict[str, Path] = {}
    for kind in ("images", "labels"):
        for split in ("train", "val"):
            target = out_root / kind / split
            target.mkdir(parents=True, exist_ok=True)
            layout[f"{kind}_{split}"] = target
    return layout
def _yolo_label_lines(image: CocoImage, anns: List[CocoAnnotation]) -> List[str]:
    """Convert the COCO annotations of one image into YOLO label lines."""
    lines: List[str] = []
    for ann in anns:
        digit_str = ann.property_info.strip()
        # isdecimal() matches exactly the characters int() can parse;
        # isdigit() also accepts e.g. superscript digits, on which int() raises.
        if not digit_str.isdecimal():
            continue
        digit = int(digit_str)
        if digit > 9:
            continue
        if image.width <= 0 or image.height <= 0:
            raise ValueError(
                f"Image {image.file_name!r} has non-positive size "
                f"{image.width}x{image.height}; cannot normalise its boxes"
            )
        x, y, w, h = ann.bbox
        # Clip the box to the image so every normalised value stays in [0, 1].
        x_min = max(0.0, x)
        y_min = max(0.0, y)
        x_max = min(float(image.width), x + w)
        y_max = min(float(image.height), y + h)
        if x_max <= x_min or y_max <= y_min:
            # Nothing of the box is left inside the image (or its size was
            # zero/negative to begin with).
            continue
        x_center = (x_min + x_max) / 2 / image.width
        y_center = (y_min + y_max) / 2 / image.height
        w_norm = (x_max - x_min) / image.width
        h_norm = (y_max - y_min) / image.height
        lines.append(f"{digit} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}")
    return lines


def coco_to_yolo(
    images: List[CocoImage],
    annotations: List[CocoAnnotation],
    image_dir: Path,
    out_root: Path,
    val_ratio: float,
    seed: int,
) -> Path:
    """Write a YOLO dataset (images, label files, dataset.yaml) under *out_root*.

    Images whose file is missing from *image_dir* are dropped. The remaining
    images are shuffled with ``random.Random(seed)`` and split: the first
    ``int(n * (1 - val_ratio))`` go to ``train``, the rest to ``val``. Each
    image is copied into its split directory and gets a ``.txt`` label file
    with one ``<digit> <x_center> <y_center> <width> <height>`` line per usable
    annotation (an empty file when there is none). Annotations whose label is
    not a number from 0 to 9, or whose box has no area inside the image, are
    skipped.

    Nothing that already exists under *out_root* is deleted; files with the
    same name are overwritten.

    Args:
        images: image records from the COCO file.
        annotations: annotation records from the COCO file.
        image_dir: directory holding the source image files.
        out_root: root directory of the YOLO dataset to write.
        val_ratio: fraction of the images assigned to the validation split.
        seed: seed of the shuffle that precedes the split.

    Returns:
        Path: location of the generated ``dataset.yaml``.

    Raises:
        ValueError: if *val_ratio* is outside [0, 1], or an image that has
            usable annotations reports a non-positive width or height.
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio}")

    image_to_annotations: Dict[int, List[CocoAnnotation]] = {}
    for ann in annotations:
        image_to_annotations.setdefault(ann.image_id, []).append(ann)

    # is_file() rather than exists(): a directory must not pass as an image.
    valid_images = [img for img in images if (image_dir / img.file_name).is_file()]
    random.Random(seed).shuffle(valid_images)
    split_idx = int(len(valid_images) * (1 - val_ratio))
    train_imgs = valid_images[:split_idx]
    val_imgs = valid_images[split_idx:]

    dirs = ensure_dirs(out_root)

    for split, split_images in (("train", train_imgs), ("val", val_imgs)):
        dst_img_dir = dirs[f"images_{split}"]
        dst_lbl_dir = dirs[f"labels_{split}"]
        for image in split_images:
            dst_img_path = dst_img_dir / image.file_name
            # file_name may contain sub-directories; make sure they exist.
            dst_img_path.parent.mkdir(parents=True, exist_ok=True)
            dst_img_path.write_bytes((image_dir / image.file_name).read_bytes())

            lines = _yolo_label_lines(image, image_to_annotations.get(image.id, []))
            # Same relative name as the image, final extension replaced by .txt.
            dst_lbl_path = dst_lbl_dir / (image.file_name.rsplit(".", 1)[0] + ".txt")
            dst_lbl_path.parent.mkdir(parents=True, exist_ok=True)
            dst_lbl_path.write_text("\n".join(lines), encoding="utf-8")

    data_yaml = {
        "path": str(out_root.resolve()),
        "train": "images/train",
        "val": "images/val",
        "names": {i: str(i) for i in range(10)},
    }
    data_yaml_path = out_root / "dataset.yaml"
    # The structure is a flat mapping plus one nested mapping, so it is
    # written by hand rather than through a YAML library.
    with data_yaml_path.open("w", encoding="utf-8") as f:
        f.write("# auto-generated by scripts/prepare_yolo_dataset.py\n")
        for key, value in data_yaml.items():
            if isinstance(value, dict):
                f.write(f"{key}:\n")
                for k, v in value.items():
                    f.write(f" {k}: {v}\n")
            else:
                f.write(f"{key}: {value}\n")

    print(f"YOLO dataset prepared at {out_root}")
    print(f"Train images: {len(train_imgs)}, Val images: {len(val_imgs)}")
    print(f"Data config written to: {data_yaml_path}")
    return data_yaml_path
def main() -> None:
    """Run the conversion from the command line.

    Parses the CLI options, checks that ``<root>/images`` exists, loads the
    COCO annotations from ``<root>`` and writes the YOLO dataset to the
    requested output directory.

    Raises:
        FileNotFoundError: if the ``images`` directory under the dataset root
            does not exist.
    """
    args = parse_args()
    image_dir = args.root / "images"
    if not image_dir.exists():
        raise FileNotFoundError(f"Images directory not found at {image_dir}")

    images, annotations = load_coco(args.root)
    coco_to_yolo(
        images=images,
        annotations=annotations,
        image_dir=image_dir,
        out_root=args.out,
        val_ratio=args.val_ratio,
        seed=args.seed,
    )


# Only run the conversion when executed as a script, not when imported.
if __name__ == "__main__":
    main()