"""
COCO-to-YOLO dataset conversion tool.

Overview:
    Converts a digit-annotation dataset stored in COCO format into the layout
    required for YOLO training. COCO keeps all annotations in one JSON file;
    YOLO keeps them in plain-text files, one per image.

Main features:
    - Parse the COCO annotation file (coco.json)
    - Extract image and bounding-box information
    - Convert bounding boxes: COCO [x, y, w, h] -> YOLO [x_center, y_center, w, h] (normalized)
    - Automatically split the data into training and validation sets
    - Create the standard YOLO directory structure
    - Generate the dataset.yaml configuration file

Format conversion in detail:
    COCO format:
    - bbox: [x, y, width, height] (pixel coordinates, top-left corner)
    - Absolute coordinates, measured in pixels

    YOLO format:
    - bbox: [x_center, y_center, width, height] (normalized coordinates)
    - Relative coordinates, values between 0 and 1
    - x_center = (x + width/2) / img_width
    - y_center = (y + height/2) / img_height

Directory structure:
    Input (COCO format):
    digit-validation/
    ├── coco.json          # annotation file
    └── images/            # image files

    Output (YOLO format):
    yolo_dataset/
    ├── dataset.yaml       # configuration file
    ├── images/
    │   ├── train/         # training images
    │   └── val/           # validation images
    └── labels/
        ├── train/         # training labels (.txt)
        └── val/           # validation labels (.txt)

Data split:
    - A fixed random seed keeps the split reproducible
    - 20% of the images go to the validation set by default
    - Each image stays paired with its label file

Usage examples:
    # Basic usage (default parameters)
    python scripts/prepare_yolo_dataset.py

    # Custom parameters
    python scripts/prepare_yolo_dataset.py \
        --root digit-validation \
        --out yolo_dataset \
        --val-ratio 0.2 \
        --seed 42

    # Training set only (no validation split)
    python scripts/prepare_yolo_dataset.py --val-ratio 0.0

Notes:
    - Make sure every file_name in coco.json matches an actual image file
    - Class IDs must be 0-9 (digits)
    - Bounding-box coordinates must not extend beyond the image
    - Files in the output directory are overwritten; back it up first

Author: Gavin Chan
Version: 1.0
Date: 2025-10-30
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import random
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Dict, List
|
||
|
||
|
||
@dataclass
class CocoImage:
    """Metadata for a single image record taken from a COCO file."""

    # Identifier that annotations use (via ``image_id``) to refer to this image.
    id: int
    # Name of the image file as written in the COCO record.
    file_name: str
    # Image width as written in the COCO record.
    width: int
    # Image height as written in the COCO record.
    height: int
|
||
|
||
|
||
@dataclass
class CocoAnnotation:
    """A single annotation record taken from a COCO file."""

    # Identifier of the annotation itself.
    id: int
    # Identifier of the image this annotation belongs to.
    image_id: int
    # Box as [x, y, width, height].
    bbox: List[float]
    # Free-text attribute attached to the annotation.
    property_info: str
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: parsed options with the attributes
            - root: directory holding the COCO dataset
            - out: directory the YOLO dataset is written to
            - val_ratio: fraction of images used for validation, within [0, 1]
            - seed: random seed used for the train/val split

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            ``--val-ratio`` that is not a number within [0, 1].
    """

    def ratio_in_unit_interval(text: str) -> float:
        # Reject out-of-range ratios at the CLI boundary; a ratio outside
        # [0, 1] would otherwise produce a meaningless split later on.
        try:
            ratio = float(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
        # The chained comparison is also False for NaN, so NaN is rejected.
        if not 0.0 <= ratio <= 1.0:
            raise argparse.ArgumentTypeError(f"must be within [0, 1], got {text!r}")
        return ratio

    parser = argparse.ArgumentParser(description="Prepare YOLO dataset from digit-validation COCO json")
    parser.add_argument("--root", type=Path, default=Path("digit-validation"), help="digit-validation directory")
    parser.add_argument("--out", type=Path, default=Path("yolo_dataset"), help="output dataset directory")
    parser.add_argument("--val-ratio", type=ratio_in_unit_interval, default=0.2, help="validation split ratio")
    parser.add_argument("--seed", type=int, default=20240305, help="random seed")
    return parser.parse_args()
|
||
|
||
|
||
def load_coco(root: Path) -> tuple[List[CocoImage], List[CocoAnnotation]]:
    """Load and parse the COCO annotation file found under ``root``.

    Args:
        root (Path): dataset root directory; ``coco.json`` is read from it.

    Returns:
        tuple: the pair ``(images, annotations)``
            - images: one CocoImage per entry of the JSON ``images`` list
            - annotations: one CocoAnnotation per entry of the JSON
              ``annotations`` list

    Raises:
        FileNotFoundError: if ``coco.json`` does not exist under ``root``.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if a required key is missing from the JSON data.
    """
    coco_path = root / "coco.json"
    if not coco_path.exists():
        raise FileNotFoundError(f"COCO file not found at {coco_path}")

    with coco_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    images = [
        CocoImage(id=img["id"], file_name=img["file_name"], width=img["width"], height=img["height"])
        for img in data["images"]
    ]

    annotations: List[CocoAnnotation] = []
    for ann in data["annotations"]:
        # "property_info" may be absent, null, or a non-string value such as
        # the number 5. Calling .strip() directly on those would raise, so
        # normalise to a stripped string first. The explicit None test keeps
        # an integer 0 from being turned into the empty string.
        raw_info = ann.get("property_info")
        property_info = "" if raw_info is None else str(raw_info).strip()
        annotations.append(
            CocoAnnotation(
                id=ann["id"],
                image_id=ann["image_id"],
                bbox=ann["bbox"],
                property_info=property_info,
            )
        )
    return images, annotations
|
||
|
||
|
||
def ensure_dirs(out_root: Path) -> Dict[str, Path]:
    """Create the image/label directory tree below ``out_root``.

    The following directories are created (missing parents included;
    directories that already exist are left as they are):
        - images/train
        - images/val
        - labels/train
        - labels/val

    Args:
        out_root (Path): root directory of the output dataset.

    Returns:
        Dict[str, Path]: the four directories keyed by ``images_train``,
        ``images_val``, ``labels_train`` and ``labels_val``.
    """
    layout = (
        ("images_train", "images", "train"),
        ("images_val", "images", "val"),
        ("labels_train", "labels", "train"),
        ("labels_val", "labels", "val"),
    )
    # Build the whole mapping before touching the file system.
    created: Dict[str, Path] = {key: out_root / kind / split for key, kind, split in layout}
    for target in created.values():
        target.mkdir(parents=True, exist_ok=True)
    return created
|
||
|
||
|
||
def _yolo_label_lines(image: CocoImage, anns: List[CocoAnnotation]) -> List[str]:
    """Build the YOLO label lines for the annotations of one image."""
    lines: List[str] = []
    for ann in anns:
        digit_str = ann.property_info.strip()
        # isdecimal() rather than isdigit(): isdigit() is also true for
        # characters such as superscript digits, which int() cannot parse.
        if not digit_str.isdecimal():
            continue
        digit = int(digit_str)
        if digit > 9:
            continue

        x, y, w, h = ann.bbox
        # Clip the box to the image so normalised values stay within [0, 1].
        left, top = max(x, 0), max(y, 0)
        right, bottom = min(x + w, image.width), min(y + h, image.height)
        # Drop boxes with no area inside the image. This also guarantees
        # image.width and image.height are positive before the divisions.
        if right <= left or bottom <= top:
            continue
        # Only replace the values when clipping changed something, so boxes
        # already inside the image produce exactly the same numbers as before.
        if (left, top, right, bottom) != (x, y, x + w, y + h):
            x, y, w, h = left, top, right - left, bottom - top

        x_center = (x + w / 2) / image.width
        y_center = (y + h / 2) / image.height
        w_norm = w / image.width
        h_norm = h / image.height
        lines.append(f"{digit} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}")
    return lines


def coco_to_yolo(
    images: List[CocoImage],
    annotations: List[CocoAnnotation],
    image_dir: Path,
    out_root: Path,
    val_ratio: float,
    seed: int,
) -> Path:
    """Write a YOLO dataset (images, labels, dataset.yaml) under ``out_root``.

    Images whose file is not present in ``image_dir`` are left out. The
    remaining images are shuffled with ``seed`` and split into a training and
    a validation part. Each image is copied and receives a label file with
    one line per usable annotation: ``<digit> <x_center> <y_center> <w> <h>``,
    all coordinates divided by the image size.

    An annotation is usable when its ``property_info`` is a number from 0 to 9
    and its box has a positive area inside the image. Boxes reaching past the
    image border are clipped to it.

    Args:
        images: image records to convert.
        annotations: annotation records, matched to images via ``image_id``.
        image_dir: directory the source image files are read from.
        out_root: root directory of the generated dataset.
        val_ratio: fraction of the images used for validation, within [0, 1].
        seed: seed of the shuffle that determines the split.

    Returns:
        Path: location of the written ``dataset.yaml``.

    Raises:
        ValueError: if ``val_ratio`` is not within [0, 1].
    """
    # Also rejects NaN, for which both comparisons are False.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio}")

    image_to_annotations: Dict[int, List[CocoAnnotation]] = {}
    for ann in annotations:
        image_to_annotations.setdefault(ann.image_id, []).append(ann)

    # is_file() instead of exists(): an empty file_name resolves to image_dir
    # itself, which exists but cannot be read as an image.
    valid_images = [img for img in images if (image_dir / img.file_name).is_file()]
    random.Random(seed).shuffle(valid_images)

    split_idx = int(len(valid_images) * (1 - val_ratio))
    train_imgs = valid_images[:split_idx]
    val_imgs = valid_images[split_idx:]

    dirs = ensure_dirs(out_root)

    for split, subset in (("train", train_imgs), ("val", val_imgs)):
        dst_img_dir = dirs[f"images_{split}"]
        dst_lbl_dir = dirs[f"labels_{split}"]
        for image in subset:
            dst_img_path = dst_img_dir / image.file_name
            dst_lbl_path = dst_lbl_dir / (image.file_name.rsplit(".", 1)[0] + ".txt")
            # file_name may contain sub-directories; mirror them on both sides.
            dst_img_path.parent.mkdir(parents=True, exist_ok=True)
            dst_lbl_path.parent.mkdir(parents=True, exist_ok=True)

            dst_img_path.write_bytes((image_dir / image.file_name).read_bytes())
            label_lines = _yolo_label_lines(image, image_to_annotations.get(image.id, []))
            dst_lbl_path.write_text("\n".join(label_lines), encoding="utf-8")

    data_yaml_path = out_root / "dataset.yaml"
    yaml_lines = [
        "# auto-generated by scripts/prepare_yolo_dataset.py",
        f"path: {out_root.resolve()}",
        "train: images/train",
        "val: images/val",
        "names:",
    ]
    # One class per digit; the class name is the digit itself.
    yaml_lines.extend(f" {i}: {i}" for i in range(10))
    data_yaml_path.write_text("\n".join(yaml_lines) + "\n", encoding="utf-8")

    print(f"YOLO dataset prepared at {out_root}")
    print(f"Train images: {len(train_imgs)}, Val images: {len(val_imgs)}")
    print(f"Data config written to: {data_yaml_path}")
    return data_yaml_path
|
||
|
||
|
||
def main() -> None:
    """Script entry point: parse the options, load the COCO data, convert it.

    Raises:
        FileNotFoundError: if the ``images`` directory is missing under the
            dataset root given on the command line.
    """
    options = parse_args()
    dataset_root = options.root
    pictures_dir = dataset_root / "images"
    if not pictures_dir.exists():
        raise FileNotFoundError(f"Images directory not found at {pictures_dir}")

    coco_images, coco_annotations = load_coco(dataset_root)
    coco_to_yolo(
        images=coco_images,
        annotations=coco_annotations,
        image_dir=pictures_dir,
        out_root=options.out,
        val_ratio=options.val_ratio,
        seed=options.seed,
    )


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|