first commit
This commit is contained in:
280
scripts/prepare_yolo_dataset.py
Normal file
280
scripts/prepare_yolo_dataset.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
COCO到YOLO数据集转换工具
|
||||
|
||||
功能说明:
|
||||
将COCO格式的数字标注数据集转换为YOLO训练所需的格式。
|
||||
COCO格式使用JSON存储标注,YOLO格式使用文本文件存储。
|
||||
|
||||
主要功能:
|
||||
- 解析COCO格式的标注文件(coco.json)
|
||||
- 提取图片和边界框信息
|
||||
- 转换边界框格式:COCO [x,y,w,h] → YOLO [x_center,y_center,w,h] (归一化)
|
||||
- 自动划分训练集和验证集
|
||||
- 创建YOLO标准目录结构
|
||||
- 生成dataset.yaml配置文件
|
||||
|
||||
格式转换详解:
|
||||
COCO格式:
|
||||
- bbox: [x, y, width, height] (像素坐标,左上角)
|
||||
- 绝对坐标,单位为像素
|
||||
|
||||
YOLO格式:
|
||||
- bbox: [x_center, y_center, width, height] (归一化坐标)
|
||||
- 相对坐标,值在0-1之间
|
||||
- x_center = (x + width/2) / img_width
|
||||
- y_center = (y + height/2) / img_height
|
||||
|
||||
目录结构:
|
||||
输入(COCO格式):
|
||||
digit-validation/
|
||||
├── coco.json # 标注文件
|
||||
└── images/ # 图片文件
|
||||
|
||||
输出(YOLO格式):
|
||||
yolo_dataset/
|
||||
├── dataset.yaml # 配置文件
|
||||
├── images/
|
||||
│ ├── train/ # 训练集图片
|
||||
│ └── val/ # 验证集图片
|
||||
└── labels/
|
||||
├── train/ # 训练集标注(.txt)
|
||||
└── val/ # 验证集标注(.txt)
|
||||
|
||||
数据划分:
|
||||
- 使用固定随机种子保证可重复性
|
||||
- 默认20%作为验证集
|
||||
- 保持图片和标注的对应关系
|
||||
|
||||
使用示例:
|
||||
# 基础使用(默认参数)
|
||||
python scripts/prepare_yolo_dataset.py
|
||||
|
||||
# 自定义参数
|
||||
python scripts/prepare_yolo_dataset.py \
|
||||
--root digit-validation \
|
||||
--out yolo_dataset \
|
||||
--val-ratio 0.2 \
|
||||
--seed 42
|
||||
|
||||
# 只使用训练集(不划分验证集)
|
||||
python scripts/prepare_yolo_dataset.py --val-ratio 0.0
|
||||
|
||||
注意事项:
|
||||
- 确保coco.json中的file_name与实际图片文件匹配
|
||||
- 类别ID必须是0-9(数字)
|
||||
- 边界框坐标不能超出图片范围
|
||||
- 输出目录会被覆盖,注意备份
|
||||
|
||||
作者: Gavin Chan
|
||||
版本: 1.0
|
||||
日期: 2025-10-30
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class CocoImage:
|
||||
id: int
|
||||
file_name: str
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CocoAnnotation:
|
||||
id: int
|
||||
image_id: int
|
||||
bbox: List[float] # x, y, width, height
|
||||
property_info: str
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""
|
||||
解析命令行参数
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: 参数对象
|
||||
- root: COCO数据集根目录
|
||||
- out: YOLO数据集输出目录
|
||||
- val_ratio: 验证集比例(0-1)
|
||||
- seed: 随机种子(保证可重复性)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Prepare YOLO dataset from digit-validation COCO json")
|
||||
parser.add_argument("--root", type=Path, default=Path("digit-validation"), help="digit-validation directory")
|
||||
parser.add_argument("--out", type=Path, default=Path("yolo_dataset"), help="output dataset directory")
|
||||
parser.add_argument("--val-ratio", type=float, default=0.2, help="validation split ratio")
|
||||
parser.add_argument("--seed", type=int, default=20240305, help="random seed")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_coco(root: Path) -> tuple[List[CocoImage], List[CocoAnnotation]]:
|
||||
"""
|
||||
加载并解析COCO格式的标注文件
|
||||
|
||||
Args:
|
||||
root (Path): COCO数据集根目录,应包含coco.json文件
|
||||
|
||||
Returns:
|
||||
tuple: 二元组 (images, annotations)
|
||||
- images: 图片信息列表(CocoImage对象)
|
||||
- annotations: 标注信息列表(CocoAnnotation对象)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 如果coco.json不存在
|
||||
JSONDecodeError: 如果JSON格式错误
|
||||
"""
|
||||
coco_path = root / "coco.json"
|
||||
if not coco_path.exists():
|
||||
raise FileNotFoundError(f"COCO file not found at {coco_path}")
|
||||
|
||||
with coco_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
images = [
|
||||
CocoImage(id=img["id"], file_name=img["file_name"], width=img["width"], height=img["height"])
|
||||
for img in data["images"]
|
||||
]
|
||||
annotations = [
|
||||
CocoAnnotation(
|
||||
id=ann["id"],
|
||||
image_id=ann["image_id"],
|
||||
bbox=ann["bbox"],
|
||||
property_info=ann.get("property_info", "").strip(),
|
||||
)
|
||||
for ann in data["annotations"]
|
||||
]
|
||||
return images, annotations
|
||||
|
||||
|
||||
def ensure_dirs(out_root: Path) -> Dict[str, Path]:
|
||||
"""
|
||||
创建YOLO数据集所需的目录结构
|
||||
|
||||
创建的目录:
|
||||
- images/train/ 训练集图片
|
||||
- images/val/ 验证集图片
|
||||
- labels/train/ 训练集标注
|
||||
- labels/val/ 验证集标注
|
||||
|
||||
Args:
|
||||
out_root (Path): 输出根目录
|
||||
|
||||
Returns:
|
||||
Dict[str, Path]: 目录路径字典
|
||||
- images_train: 训练集图片目录
|
||||
- images_val: 验证集图片目录
|
||||
- labels_train: 训练集标注目录
|
||||
- labels_val: 验证集标注目录
|
||||
"""
|
||||
dirs = {
|
||||
"images_train": out_root / "images" / "train",
|
||||
"images_val": out_root / "images" / "val",
|
||||
"labels_train": out_root / "labels" / "train",
|
||||
"labels_val": out_root / "labels" / "val",
|
||||
}
|
||||
for directory in dirs.values():
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
return dirs
|
||||
|
||||
|
||||
def coco_to_yolo(
|
||||
images: List[CocoImage],
|
||||
annotations: List[CocoAnnotation],
|
||||
image_dir: Path,
|
||||
out_root: Path,
|
||||
val_ratio: float,
|
||||
seed: int,
|
||||
) -> Path:
|
||||
id_to_image = {image.id: image for image in images}
|
||||
image_to_annotations: Dict[int, List[CocoAnnotation]] = {}
|
||||
for ann in annotations:
|
||||
image_to_annotations.setdefault(ann.image_id, []).append(ann)
|
||||
|
||||
valid_images = [img for img in images if (image_dir / img.file_name).exists()]
|
||||
random.Random(seed).shuffle(valid_images)
|
||||
|
||||
split_idx = int(len(valid_images) * (1 - val_ratio))
|
||||
train_imgs = valid_images[:split_idx]
|
||||
val_imgs = valid_images[split_idx:]
|
||||
|
||||
dirs = ensure_dirs(out_root)
|
||||
|
||||
def process(image: CocoImage, split: str) -> None:
|
||||
src_path = image_dir / image.file_name
|
||||
dst_img_dir = dirs["images_train"] if split == "train" else dirs["images_val"]
|
||||
dst_lbl_dir = dirs["labels_train"] if split == "train" else dirs["labels_val"]
|
||||
|
||||
dst_img_path = dst_img_dir / image.file_name
|
||||
dst_img_path.write_bytes(src_path.read_bytes())
|
||||
|
||||
anns = image_to_annotations.get(image.id, [])
|
||||
lines: List[str] = []
|
||||
|
||||
for ann in anns:
|
||||
digit_str = ann.property_info.strip()
|
||||
if not digit_str.isdigit():
|
||||
continue
|
||||
digit = int(digit_str)
|
||||
if digit < 0 or digit > 9:
|
||||
continue
|
||||
|
||||
x, y, w, h = ann.bbox
|
||||
x_center = (x + w / 2) / image.width
|
||||
y_center = (y + h / 2) / image.height
|
||||
w_norm = w / image.width
|
||||
h_norm = h / image.height
|
||||
|
||||
lines.append(f"{digit} {x_center:.6f} {y_center:.6f} {w_norm:.6f} {h_norm:.6f}")
|
||||
|
||||
dst_lbl_path = dst_lbl_dir / (image.file_name.rsplit(".", 1)[0] + ".txt")
|
||||
dst_lbl_path.write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
for img in train_imgs:
|
||||
process(img, "train")
|
||||
for img in val_imgs:
|
||||
process(img, "val")
|
||||
|
||||
data_yaml = {
|
||||
"path": str(out_root.resolve()),
|
||||
"train": "images/train",
|
||||
"val": "images/val",
|
||||
"names": {i: str(i) for i in range(10)},
|
||||
}
|
||||
|
||||
data_yaml_path = out_root / "dataset.yaml"
|
||||
with data_yaml_path.open("w", encoding="utf-8") as f:
|
||||
f.write("# auto-generated by scripts/prepare_yolo_dataset.py\n")
|
||||
for key, value in data_yaml.items():
|
||||
if isinstance(value, dict):
|
||||
f.write(f"{key}:\n")
|
||||
for k, v in value.items():
|
||||
f.write(f" {k}: {v}\n")
|
||||
else:
|
||||
f.write(f"{key}: {value}\n")
|
||||
|
||||
print(f"YOLO dataset prepared at {out_root}")
|
||||
print(f"Train images: {len(train_imgs)}, Val images: {len(val_imgs)}")
|
||||
print(f"Data config written to: {data_yaml_path}")
|
||||
return data_yaml_path
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
image_dir = args.root / "images"
|
||||
if not image_dir.exists():
|
||||
raise FileNotFoundError(f"Images directory not found at {image_dir}")
|
||||
|
||||
images, annotations = load_coco(args.root)
|
||||
coco_to_yolo(images, annotations, image_dir, args.out, args.val_ratio, args.seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user