first commit
This commit is contained in:
18
rtdetr_paddle/ppdet/data/source/__init__.py
Normal file
18
rtdetr_paddle/ppdet/data/source/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .coco import *
|
||||
from .voc import *
|
||||
from .category import *
|
||||
from .dataset import ImageFolder
|
||||
926
rtdetr_paddle/ppdet/data/source/category.py
Normal file
926
rtdetr_paddle/ppdet/data/source/category.py
Normal file
@@ -0,0 +1,926 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ppdet.data.source.voc import pascalvoc_label
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['get_categories']
|
||||
|
||||
|
||||
def get_categories(metric_type, anno_file=None, arch=None):
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map from annotation file.
|
||||
|
||||
Args:
|
||||
metric_type (str): metric type, currently support 'coco', 'voc', 'oid'
|
||||
and 'widerface'.
|
||||
anno_file (str): annotation file path
|
||||
"""
|
||||
if arch == 'keypoint_arch':
|
||||
return (None, {'id': 'keypoint'})
|
||||
|
||||
if anno_file == None or (not os.path.isfile(anno_file)):
|
||||
logger.warning(
|
||||
"anno_file '{}' is None or not set or not exist, "
|
||||
"please recheck TrainDataset/EvalDataset/TestDataset.anno_path, "
|
||||
"otherwise the default categories will be used by metric_type.".
|
||||
format(anno_file))
|
||||
|
||||
if metric_type.lower() == 'coco' or metric_type.lower(
|
||||
) == 'rbox' or metric_type.lower() == 'snipercoco':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
if anno_file.endswith('json'):
|
||||
# lazy import pycocotools here
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_file)
|
||||
cats = coco.loadCats(coco.getCatIds())
|
||||
|
||||
clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
|
||||
catid2name = {cat['id']: cat['name'] for cat in cats}
|
||||
|
||||
elif anno_file.endswith('txt'):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background': cats = cats[1:]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
else:
|
||||
raise ValueError("anno_file {} should be json or txt.".format(
|
||||
anno_file))
|
||||
return clsid2catid, catid2name
|
||||
|
||||
# anno file not exist, load default categories of COCO17
|
||||
else:
|
||||
if metric_type.lower() == 'rbox':
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of DOTA.".format(
|
||||
metric_type))
|
||||
return _dota_category()
|
||||
logger.warning("metric_type: {}, load default categories of COCO.".
|
||||
format(metric_type))
|
||||
return _coco17_category()
|
||||
|
||||
elif metric_type.lower() == 'voc':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
# anno file not exist, load default categories of
|
||||
# VOC all 20 categories
|
||||
else:
|
||||
logger.warning("metric_type: {}, load default categories of VOC.".
|
||||
format(metric_type))
|
||||
return _vocall_category()
|
||||
|
||||
elif metric_type.lower() == 'oid':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
logger.warning("only default categories support for OID19")
|
||||
return _oid19_category()
|
||||
|
||||
elif metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower(
|
||||
) == 'keypointtopdownmpiieval':
|
||||
return (None, {'id': 'keypoint'})
|
||||
|
||||
elif metric_type.lower() == 'pose3deval':
|
||||
return (None, {'id': 'pose3d'})
|
||||
|
||||
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
return clsid2catid, catid2name
|
||||
# anno file not exist, load default category 'pedestrian'.
|
||||
else:
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of pedestrian MOT.".
|
||||
format(metric_type))
|
||||
return _mot_category(category='pedestrian')
|
||||
|
||||
elif metric_type.lower() in ['kitti', 'bdd100kmot']:
|
||||
return _mot_category(category='vehicle')
|
||||
|
||||
elif metric_type.lower() in ['mcmot']:
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
return clsid2catid, catid2name
|
||||
# anno file not exist, load default categories of visdrone all 10 categories
|
||||
else:
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of VisDrone.".format(
|
||||
metric_type))
|
||||
return _visdrone_category()
|
||||
|
||||
else:
|
||||
raise ValueError("unknown metric type {}".format(metric_type))
|
||||
|
||||
|
||||
def _mot_category(category='pedestrian'):
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of mot dataset
|
||||
"""
|
||||
label_map = {category: 0}
|
||||
label_map = sorted(label_map.items(), key=lambda x: x[1])
|
||||
cats = [l[0] for l in label_map]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _coco17_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of COCO2017 dataset
|
||||
|
||||
"""
|
||||
clsid2catid = {
|
||||
1: 1,
|
||||
2: 2,
|
||||
3: 3,
|
||||
4: 4,
|
||||
5: 5,
|
||||
6: 6,
|
||||
7: 7,
|
||||
8: 8,
|
||||
9: 9,
|
||||
10: 10,
|
||||
11: 11,
|
||||
12: 13,
|
||||
13: 14,
|
||||
14: 15,
|
||||
15: 16,
|
||||
16: 17,
|
||||
17: 18,
|
||||
18: 19,
|
||||
19: 20,
|
||||
20: 21,
|
||||
21: 22,
|
||||
22: 23,
|
||||
23: 24,
|
||||
24: 25,
|
||||
25: 27,
|
||||
26: 28,
|
||||
27: 31,
|
||||
28: 32,
|
||||
29: 33,
|
||||
30: 34,
|
||||
31: 35,
|
||||
32: 36,
|
||||
33: 37,
|
||||
34: 38,
|
||||
35: 39,
|
||||
36: 40,
|
||||
37: 41,
|
||||
38: 42,
|
||||
39: 43,
|
||||
40: 44,
|
||||
41: 46,
|
||||
42: 47,
|
||||
43: 48,
|
||||
44: 49,
|
||||
45: 50,
|
||||
46: 51,
|
||||
47: 52,
|
||||
48: 53,
|
||||
49: 54,
|
||||
50: 55,
|
||||
51: 56,
|
||||
52: 57,
|
||||
53: 58,
|
||||
54: 59,
|
||||
55: 60,
|
||||
56: 61,
|
||||
57: 62,
|
||||
58: 63,
|
||||
59: 64,
|
||||
60: 65,
|
||||
61: 67,
|
||||
62: 70,
|
||||
63: 72,
|
||||
64: 73,
|
||||
65: 74,
|
||||
66: 75,
|
||||
67: 76,
|
||||
68: 77,
|
||||
69: 78,
|
||||
70: 79,
|
||||
71: 80,
|
||||
72: 81,
|
||||
73: 82,
|
||||
74: 84,
|
||||
75: 85,
|
||||
76: 86,
|
||||
77: 87,
|
||||
78: 88,
|
||||
79: 89,
|
||||
80: 90
|
||||
}
|
||||
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'person',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'motorcycle',
|
||||
5: 'airplane',
|
||||
6: 'bus',
|
||||
7: 'train',
|
||||
8: 'truck',
|
||||
9: 'boat',
|
||||
10: 'traffic light',
|
||||
11: 'fire hydrant',
|
||||
13: 'stop sign',
|
||||
14: 'parking meter',
|
||||
15: 'bench',
|
||||
16: 'bird',
|
||||
17: 'cat',
|
||||
18: 'dog',
|
||||
19: 'horse',
|
||||
20: 'sheep',
|
||||
21: 'cow',
|
||||
22: 'elephant',
|
||||
23: 'bear',
|
||||
24: 'zebra',
|
||||
25: 'giraffe',
|
||||
27: 'backpack',
|
||||
28: 'umbrella',
|
||||
31: 'handbag',
|
||||
32: 'tie',
|
||||
33: 'suitcase',
|
||||
34: 'frisbee',
|
||||
35: 'skis',
|
||||
36: 'snowboard',
|
||||
37: 'sports ball',
|
||||
38: 'kite',
|
||||
39: 'baseball bat',
|
||||
40: 'baseball glove',
|
||||
41: 'skateboard',
|
||||
42: 'surfboard',
|
||||
43: 'tennis racket',
|
||||
44: 'bottle',
|
||||
46: 'wine glass',
|
||||
47: 'cup',
|
||||
48: 'fork',
|
||||
49: 'knife',
|
||||
50: 'spoon',
|
||||
51: 'bowl',
|
||||
52: 'banana',
|
||||
53: 'apple',
|
||||
54: 'sandwich',
|
||||
55: 'orange',
|
||||
56: 'broccoli',
|
||||
57: 'carrot',
|
||||
58: 'hot dog',
|
||||
59: 'pizza',
|
||||
60: 'donut',
|
||||
61: 'cake',
|
||||
62: 'chair',
|
||||
63: 'couch',
|
||||
64: 'potted plant',
|
||||
65: 'bed',
|
||||
67: 'dining table',
|
||||
70: 'toilet',
|
||||
72: 'tv',
|
||||
73: 'laptop',
|
||||
74: 'mouse',
|
||||
75: 'remote',
|
||||
76: 'keyboard',
|
||||
77: 'cell phone',
|
||||
78: 'microwave',
|
||||
79: 'oven',
|
||||
80: 'toaster',
|
||||
81: 'sink',
|
||||
82: 'refrigerator',
|
||||
84: 'book',
|
||||
85: 'clock',
|
||||
86: 'vase',
|
||||
87: 'scissors',
|
||||
88: 'teddy bear',
|
||||
89: 'hair drier',
|
||||
90: 'toothbrush'
|
||||
}
|
||||
|
||||
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
|
||||
catid2name.pop(0)
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _dota_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of dota dataset
|
||||
"""
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'plane',
|
||||
2: 'baseball-diamond',
|
||||
3: 'bridge',
|
||||
4: 'ground-track-field',
|
||||
5: 'small-vehicle',
|
||||
6: 'large-vehicle',
|
||||
7: 'ship',
|
||||
8: 'tennis-court',
|
||||
9: 'basketball-court',
|
||||
10: 'storage-tank',
|
||||
11: 'soccer-ball-field',
|
||||
12: 'roundabout',
|
||||
13: 'harbor',
|
||||
14: 'swimming-pool',
|
||||
15: 'helicopter'
|
||||
}
|
||||
catid2name.pop(0)
|
||||
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _vocall_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of mixup voc dataset
|
||||
|
||||
"""
|
||||
label_map = pascalvoc_label()
|
||||
label_map = sorted(label_map.items(), key=lambda x: x[1])
|
||||
cats = [l[0] for l in label_map]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _oid19_category():
|
||||
clsid2catid = {k: k + 1 for k in range(500)}
|
||||
|
||||
catid2name = {
|
||||
0: "background",
|
||||
1: "Infant bed",
|
||||
2: "Rose",
|
||||
3: "Flag",
|
||||
4: "Flashlight",
|
||||
5: "Sea turtle",
|
||||
6: "Camera",
|
||||
7: "Animal",
|
||||
8: "Glove",
|
||||
9: "Crocodile",
|
||||
10: "Cattle",
|
||||
11: "House",
|
||||
12: "Guacamole",
|
||||
13: "Penguin",
|
||||
14: "Vehicle registration plate",
|
||||
15: "Bench",
|
||||
16: "Ladybug",
|
||||
17: "Human nose",
|
||||
18: "Watermelon",
|
||||
19: "Flute",
|
||||
20: "Butterfly",
|
||||
21: "Washing machine",
|
||||
22: "Raccoon",
|
||||
23: "Segway",
|
||||
24: "Taco",
|
||||
25: "Jellyfish",
|
||||
26: "Cake",
|
||||
27: "Pen",
|
||||
28: "Cannon",
|
||||
29: "Bread",
|
||||
30: "Tree",
|
||||
31: "Shellfish",
|
||||
32: "Bed",
|
||||
33: "Hamster",
|
||||
34: "Hat",
|
||||
35: "Toaster",
|
||||
36: "Sombrero",
|
||||
37: "Tiara",
|
||||
38: "Bowl",
|
||||
39: "Dragonfly",
|
||||
40: "Moths and butterflies",
|
||||
41: "Antelope",
|
||||
42: "Vegetable",
|
||||
43: "Torch",
|
||||
44: "Building",
|
||||
45: "Power plugs and sockets",
|
||||
46: "Blender",
|
||||
47: "Billiard table",
|
||||
48: "Cutting board",
|
||||
49: "Bronze sculpture",
|
||||
50: "Turtle",
|
||||
51: "Broccoli",
|
||||
52: "Tiger",
|
||||
53: "Mirror",
|
||||
54: "Bear",
|
||||
55: "Zucchini",
|
||||
56: "Dress",
|
||||
57: "Volleyball",
|
||||
58: "Guitar",
|
||||
59: "Reptile",
|
||||
60: "Golf cart",
|
||||
61: "Tart",
|
||||
62: "Fedora",
|
||||
63: "Carnivore",
|
||||
64: "Car",
|
||||
65: "Lighthouse",
|
||||
66: "Coffeemaker",
|
||||
67: "Food processor",
|
||||
68: "Truck",
|
||||
69: "Bookcase",
|
||||
70: "Surfboard",
|
||||
71: "Footwear",
|
||||
72: "Bench",
|
||||
73: "Necklace",
|
||||
74: "Flower",
|
||||
75: "Radish",
|
||||
76: "Marine mammal",
|
||||
77: "Frying pan",
|
||||
78: "Tap",
|
||||
79: "Peach",
|
||||
80: "Knife",
|
||||
81: "Handbag",
|
||||
82: "Laptop",
|
||||
83: "Tent",
|
||||
84: "Ambulance",
|
||||
85: "Christmas tree",
|
||||
86: "Eagle",
|
||||
87: "Limousine",
|
||||
88: "Kitchen & dining room table",
|
||||
89: "Polar bear",
|
||||
90: "Tower",
|
||||
91: "Football",
|
||||
92: "Willow",
|
||||
93: "Human head",
|
||||
94: "Stop sign",
|
||||
95: "Banana",
|
||||
96: "Mixer",
|
||||
97: "Binoculars",
|
||||
98: "Dessert",
|
||||
99: "Bee",
|
||||
100: "Chair",
|
||||
101: "Wood-burning stove",
|
||||
102: "Flowerpot",
|
||||
103: "Beaker",
|
||||
104: "Oyster",
|
||||
105: "Woodpecker",
|
||||
106: "Harp",
|
||||
107: "Bathtub",
|
||||
108: "Wall clock",
|
||||
109: "Sports uniform",
|
||||
110: "Rhinoceros",
|
||||
111: "Beehive",
|
||||
112: "Cupboard",
|
||||
113: "Chicken",
|
||||
114: "Man",
|
||||
115: "Blue jay",
|
||||
116: "Cucumber",
|
||||
117: "Balloon",
|
||||
118: "Kite",
|
||||
119: "Fireplace",
|
||||
120: "Lantern",
|
||||
121: "Missile",
|
||||
122: "Book",
|
||||
123: "Spoon",
|
||||
124: "Grapefruit",
|
||||
125: "Squirrel",
|
||||
126: "Orange",
|
||||
127: "Coat",
|
||||
128: "Punching bag",
|
||||
129: "Zebra",
|
||||
130: "Billboard",
|
||||
131: "Bicycle",
|
||||
132: "Door handle",
|
||||
133: "Mechanical fan",
|
||||
134: "Ring binder",
|
||||
135: "Table",
|
||||
136: "Parrot",
|
||||
137: "Sock",
|
||||
138: "Vase",
|
||||
139: "Weapon",
|
||||
140: "Shotgun",
|
||||
141: "Glasses",
|
||||
142: "Seahorse",
|
||||
143: "Belt",
|
||||
144: "Watercraft",
|
||||
145: "Window",
|
||||
146: "Giraffe",
|
||||
147: "Lion",
|
||||
148: "Tire",
|
||||
149: "Vehicle",
|
||||
150: "Canoe",
|
||||
151: "Tie",
|
||||
152: "Shelf",
|
||||
153: "Picture frame",
|
||||
154: "Printer",
|
||||
155: "Human leg",
|
||||
156: "Boat",
|
||||
157: "Slow cooker",
|
||||
158: "Croissant",
|
||||
159: "Candle",
|
||||
160: "Pancake",
|
||||
161: "Pillow",
|
||||
162: "Coin",
|
||||
163: "Stretcher",
|
||||
164: "Sandal",
|
||||
165: "Woman",
|
||||
166: "Stairs",
|
||||
167: "Harpsichord",
|
||||
168: "Stool",
|
||||
169: "Bus",
|
||||
170: "Suitcase",
|
||||
171: "Human mouth",
|
||||
172: "Juice",
|
||||
173: "Skull",
|
||||
174: "Door",
|
||||
175: "Violin",
|
||||
176: "Chopsticks",
|
||||
177: "Digital clock",
|
||||
178: "Sunflower",
|
||||
179: "Leopard",
|
||||
180: "Bell pepper",
|
||||
181: "Harbor seal",
|
||||
182: "Snake",
|
||||
183: "Sewing machine",
|
||||
184: "Goose",
|
||||
185: "Helicopter",
|
||||
186: "Seat belt",
|
||||
187: "Coffee cup",
|
||||
188: "Microwave oven",
|
||||
189: "Hot dog",
|
||||
190: "Countertop",
|
||||
191: "Serving tray",
|
||||
192: "Dog bed",
|
||||
193: "Beer",
|
||||
194: "Sunglasses",
|
||||
195: "Golf ball",
|
||||
196: "Waffle",
|
||||
197: "Palm tree",
|
||||
198: "Trumpet",
|
||||
199: "Ruler",
|
||||
200: "Helmet",
|
||||
201: "Ladder",
|
||||
202: "Office building",
|
||||
203: "Tablet computer",
|
||||
204: "Toilet paper",
|
||||
205: "Pomegranate",
|
||||
206: "Skirt",
|
||||
207: "Gas stove",
|
||||
208: "Cookie",
|
||||
209: "Cart",
|
||||
210: "Raven",
|
||||
211: "Egg",
|
||||
212: "Burrito",
|
||||
213: "Goat",
|
||||
214: "Kitchen knife",
|
||||
215: "Skateboard",
|
||||
216: "Salt and pepper shakers",
|
||||
217: "Lynx",
|
||||
218: "Boot",
|
||||
219: "Platter",
|
||||
220: "Ski",
|
||||
221: "Swimwear",
|
||||
222: "Swimming pool",
|
||||
223: "Drinking straw",
|
||||
224: "Wrench",
|
||||
225: "Drum",
|
||||
226: "Ant",
|
||||
227: "Human ear",
|
||||
228: "Headphones",
|
||||
229: "Fountain",
|
||||
230: "Bird",
|
||||
231: "Jeans",
|
||||
232: "Television",
|
||||
233: "Crab",
|
||||
234: "Microphone",
|
||||
235: "Home appliance",
|
||||
236: "Snowplow",
|
||||
237: "Beetle",
|
||||
238: "Artichoke",
|
||||
239: "Jet ski",
|
||||
240: "Stationary bicycle",
|
||||
241: "Human hair",
|
||||
242: "Brown bear",
|
||||
243: "Starfish",
|
||||
244: "Fork",
|
||||
245: "Lobster",
|
||||
246: "Corded phone",
|
||||
247: "Drink",
|
||||
248: "Saucer",
|
||||
249: "Carrot",
|
||||
250: "Insect",
|
||||
251: "Clock",
|
||||
252: "Castle",
|
||||
253: "Tennis racket",
|
||||
254: "Ceiling fan",
|
||||
255: "Asparagus",
|
||||
256: "Jaguar",
|
||||
257: "Musical instrument",
|
||||
258: "Train",
|
||||
259: "Cat",
|
||||
260: "Rifle",
|
||||
261: "Dumbbell",
|
||||
262: "Mobile phone",
|
||||
263: "Taxi",
|
||||
264: "Shower",
|
||||
265: "Pitcher",
|
||||
266: "Lemon",
|
||||
267: "Invertebrate",
|
||||
268: "Turkey",
|
||||
269: "High heels",
|
||||
270: "Bust",
|
||||
271: "Elephant",
|
||||
272: "Scarf",
|
||||
273: "Barrel",
|
||||
274: "Trombone",
|
||||
275: "Pumpkin",
|
||||
276: "Box",
|
||||
277: "Tomato",
|
||||
278: "Frog",
|
||||
279: "Bidet",
|
||||
280: "Human face",
|
||||
281: "Houseplant",
|
||||
282: "Van",
|
||||
283: "Shark",
|
||||
284: "Ice cream",
|
||||
285: "Swim cap",
|
||||
286: "Falcon",
|
||||
287: "Ostrich",
|
||||
288: "Handgun",
|
||||
289: "Whiteboard",
|
||||
290: "Lizard",
|
||||
291: "Pasta",
|
||||
292: "Snowmobile",
|
||||
293: "Light bulb",
|
||||
294: "Window blind",
|
||||
295: "Muffin",
|
||||
296: "Pretzel",
|
||||
297: "Computer monitor",
|
||||
298: "Horn",
|
||||
299: "Furniture",
|
||||
300: "Sandwich",
|
||||
301: "Fox",
|
||||
302: "Convenience store",
|
||||
303: "Fish",
|
||||
304: "Fruit",
|
||||
305: "Earrings",
|
||||
306: "Curtain",
|
||||
307: "Grape",
|
||||
308: "Sofa bed",
|
||||
309: "Horse",
|
||||
310: "Luggage and bags",
|
||||
311: "Desk",
|
||||
312: "Crutch",
|
||||
313: "Bicycle helmet",
|
||||
314: "Tick",
|
||||
315: "Airplane",
|
||||
316: "Canary",
|
||||
317: "Spatula",
|
||||
318: "Watch",
|
||||
319: "Lily",
|
||||
320: "Kitchen appliance",
|
||||
321: "Filing cabinet",
|
||||
322: "Aircraft",
|
||||
323: "Cake stand",
|
||||
324: "Candy",
|
||||
325: "Sink",
|
||||
326: "Mouse",
|
||||
327: "Wine",
|
||||
328: "Wheelchair",
|
||||
329: "Goldfish",
|
||||
330: "Refrigerator",
|
||||
331: "French fries",
|
||||
332: "Drawer",
|
||||
333: "Treadmill",
|
||||
334: "Picnic basket",
|
||||
335: "Dice",
|
||||
336: "Cabbage",
|
||||
337: "Football helmet",
|
||||
338: "Pig",
|
||||
339: "Person",
|
||||
340: "Shorts",
|
||||
341: "Gondola",
|
||||
342: "Honeycomb",
|
||||
343: "Doughnut",
|
||||
344: "Chest of drawers",
|
||||
345: "Land vehicle",
|
||||
346: "Bat",
|
||||
347: "Monkey",
|
||||
348: "Dagger",
|
||||
349: "Tableware",
|
||||
350: "Human foot",
|
||||
351: "Mug",
|
||||
352: "Alarm clock",
|
||||
353: "Pressure cooker",
|
||||
354: "Human hand",
|
||||
355: "Tortoise",
|
||||
356: "Baseball glove",
|
||||
357: "Sword",
|
||||
358: "Pear",
|
||||
359: "Miniskirt",
|
||||
360: "Traffic sign",
|
||||
361: "Girl",
|
||||
362: "Roller skates",
|
||||
363: "Dinosaur",
|
||||
364: "Porch",
|
||||
365: "Human beard",
|
||||
366: "Submarine sandwich",
|
||||
367: "Screwdriver",
|
||||
368: "Strawberry",
|
||||
369: "Wine glass",
|
||||
370: "Seafood",
|
||||
371: "Racket",
|
||||
372: "Wheel",
|
||||
373: "Sea lion",
|
||||
374: "Toy",
|
||||
375: "Tea",
|
||||
376: "Tennis ball",
|
||||
377: "Waste container",
|
||||
378: "Mule",
|
||||
379: "Cricket ball",
|
||||
380: "Pineapple",
|
||||
381: "Coconut",
|
||||
382: "Doll",
|
||||
383: "Coffee table",
|
||||
384: "Snowman",
|
||||
385: "Lavender",
|
||||
386: "Shrimp",
|
||||
387: "Maple",
|
||||
388: "Cowboy hat",
|
||||
389: "Goggles",
|
||||
390: "Rugby ball",
|
||||
391: "Caterpillar",
|
||||
392: "Poster",
|
||||
393: "Rocket",
|
||||
394: "Organ",
|
||||
395: "Saxophone",
|
||||
396: "Traffic light",
|
||||
397: "Cocktail",
|
||||
398: "Plastic bag",
|
||||
399: "Squash",
|
||||
400: "Mushroom",
|
||||
401: "Hamburger",
|
||||
402: "Light switch",
|
||||
403: "Parachute",
|
||||
404: "Teddy bear",
|
||||
405: "Winter melon",
|
||||
406: "Deer",
|
||||
407: "Musical keyboard",
|
||||
408: "Plumbing fixture",
|
||||
409: "Scoreboard",
|
||||
410: "Baseball bat",
|
||||
411: "Envelope",
|
||||
412: "Adhesive tape",
|
||||
413: "Briefcase",
|
||||
414: "Paddle",
|
||||
415: "Bow and arrow",
|
||||
416: "Telephone",
|
||||
417: "Sheep",
|
||||
418: "Jacket",
|
||||
419: "Boy",
|
||||
420: "Pizza",
|
||||
421: "Otter",
|
||||
422: "Office supplies",
|
||||
423: "Couch",
|
||||
424: "Cello",
|
||||
425: "Bull",
|
||||
426: "Camel",
|
||||
427: "Ball",
|
||||
428: "Duck",
|
||||
429: "Whale",
|
||||
430: "Shirt",
|
||||
431: "Tank",
|
||||
432: "Motorcycle",
|
||||
433: "Accordion",
|
||||
434: "Owl",
|
||||
435: "Porcupine",
|
||||
436: "Sun hat",
|
||||
437: "Nail",
|
||||
438: "Scissors",
|
||||
439: "Swan",
|
||||
440: "Lamp",
|
||||
441: "Crown",
|
||||
442: "Piano",
|
||||
443: "Sculpture",
|
||||
444: "Cheetah",
|
||||
445: "Oboe",
|
||||
446: "Tin can",
|
||||
447: "Mango",
|
||||
448: "Tripod",
|
||||
449: "Oven",
|
||||
450: "Mouse",
|
||||
451: "Barge",
|
||||
452: "Coffee",
|
||||
453: "Snowboard",
|
||||
454: "Common fig",
|
||||
455: "Salad",
|
||||
456: "Marine invertebrates",
|
||||
457: "Umbrella",
|
||||
458: "Kangaroo",
|
||||
459: "Human arm",
|
||||
460: "Measuring cup",
|
||||
461: "Snail",
|
||||
462: "Loveseat",
|
||||
463: "Suit",
|
||||
464: "Teapot",
|
||||
465: "Bottle",
|
||||
466: "Alpaca",
|
||||
467: "Kettle",
|
||||
468: "Trousers",
|
||||
469: "Popcorn",
|
||||
470: "Centipede",
|
||||
471: "Spider",
|
||||
472: "Sparrow",
|
||||
473: "Plate",
|
||||
474: "Bagel",
|
||||
475: "Personal care",
|
||||
476: "Apple",
|
||||
477: "Brassiere",
|
||||
478: "Bathroom cabinet",
|
||||
479: "studio couch",
|
||||
480: "Computer keyboard",
|
||||
481: "Table tennis racket",
|
||||
482: "Sushi",
|
||||
483: "Cabinetry",
|
||||
484: "Street light",
|
||||
485: "Towel",
|
||||
486: "Nightstand",
|
||||
487: "Rabbit",
|
||||
488: "Dolphin",
|
||||
489: "Dog",
|
||||
490: "Jug",
|
||||
491: "Wok",
|
||||
492: "Fire hydrant",
|
||||
493: "Human eye",
|
||||
494: "Skyscraper",
|
||||
495: "Backpack",
|
||||
496: "Potato",
|
||||
497: "Paper towel",
|
||||
498: "Lifejacket",
|
||||
499: "Bicycle wheel",
|
||||
500: "Toilet",
|
||||
}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _visdrone_category():
|
||||
clsid2catid = {i: i for i in range(10)}
|
||||
|
||||
catid2name = {
|
||||
0: 'pedestrian',
|
||||
1: 'people',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'van',
|
||||
5: 'truck',
|
||||
6: 'tricycle',
|
||||
7: 'awning-tricycle',
|
||||
8: 'bus',
|
||||
9: 'motor'
|
||||
}
|
||||
return clsid2catid, catid2name
|
||||
587
rtdetr_paddle/ppdet/data/source/coco.py
Normal file
587
rtdetr_paddle/ppdet/data/source/coco.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
import numpy as np
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['COCODataSet', 'SlicedCOCODataSet', 'SemiCOCODataSet']
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class COCODataSet(DetDataset):
|
||||
"""
|
||||
Load dataset with COCO format.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): coco annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
load_crowd (bool): whether to load crowded ground-truth.
|
||||
False as default
|
||||
allow_empty (bool): whether to load empty entry. False as default
|
||||
empty_ratio (float): the ratio of empty record number to total
|
||||
record's, if empty_ratio is out of [0. ,1.), do not sample the
|
||||
records and use all the empty entries. 1. as default
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1):
|
||||
super(COCODataSet, self).__init__(
|
||||
dataset_dir,
|
||||
image_dir,
|
||||
anno_path,
|
||||
data_fields,
|
||||
sample_num,
|
||||
repeat=repeat)
|
||||
self.load_image_only = False
|
||||
self.load_semantic = False
|
||||
self.load_crowd = load_crowd
|
||||
self.allow_empty = allow_empty
|
||||
self.empty_ratio = empty_ratio
|
||||
|
||||
def _sample_empty(self, records, num):
|
||||
# if empty_ratio is out of [0. ,1.), do not sample the records
|
||||
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
|
||||
return records
|
||||
import random
|
||||
sample_num = min(
|
||||
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
|
||||
records = random.sample(records, sample_num)
|
||||
return records
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
coco_rec = {
|
||||
'im_file': im_path,
|
||||
'im_id': np.array([img_id]),
|
||||
'h': im_h,
|
||||
'w': im_w,
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
if not self.load_image_only:
|
||||
ins_anno_ids = coco.getAnnIds(
|
||||
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
|
||||
instances = coco.loadAnns(ins_anno_ids)
|
||||
|
||||
bboxes = []
|
||||
is_rbox_anno = False
|
||||
for inst in instances:
|
||||
# check gt bbox
|
||||
if inst.get('ignore', False):
|
||||
continue
|
||||
if 'bbox' not in inst.keys():
|
||||
continue
|
||||
else:
|
||||
if not any(np.array(inst['bbox'])):
|
||||
continue
|
||||
|
||||
x1, y1, box_w, box_h = inst['bbox']
|
||||
x2 = x1 + box_w
|
||||
y2 = y1 + box_h
|
||||
eps = 1e-5
|
||||
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
|
||||
inst['clean_bbox'] = [
|
||||
round(float(x), 3) for x in [x1, y1, x2, y2]
|
||||
]
|
||||
bboxes.append(inst)
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: im_id: {}, '
|
||||
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
img_id, float(inst['area']), x1, y1, x2, y2))
|
||||
|
||||
num_bbox = len(bboxes)
|
||||
if num_bbox <= 0 and not self.allow_empty:
|
||||
continue
|
||||
elif num_bbox <= 0:
|
||||
is_empty = True
|
||||
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_poly = [None] * num_bbox
|
||||
gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)
|
||||
|
||||
has_segmentation = False
|
||||
has_track_id = False
|
||||
for i, box in enumerate(bboxes):
|
||||
catid = box['category_id']
|
||||
gt_class[i][0] = self.catid2clsid[catid]
|
||||
gt_bbox[i, :] = box['clean_bbox']
|
||||
is_crowd[i][0] = box['iscrowd']
|
||||
# check RLE format
|
||||
if 'segmentation' in box and box['iscrowd'] == 1:
|
||||
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
|
||||
elif 'segmentation' in box and box['segmentation']:
|
||||
if not np.array(
|
||||
box['segmentation'],
|
||||
dtype=object).size > 0 and not self.allow_empty:
|
||||
bboxes.pop(i)
|
||||
gt_poly.pop(i)
|
||||
np.delete(is_crowd, i)
|
||||
np.delete(gt_class, i)
|
||||
np.delete(gt_bbox, i)
|
||||
else:
|
||||
gt_poly[i] = box['segmentation']
|
||||
has_segmentation = True
|
||||
|
||||
if 'track_id' in box:
|
||||
gt_track_id[i][0] = box['track_id']
|
||||
has_track_id = True
|
||||
|
||||
if has_segmentation and not any(
|
||||
gt_poly) and not self.allow_empty:
|
||||
continue
|
||||
|
||||
gt_rec = {
|
||||
'is_crowd': is_crowd,
|
||||
'gt_class': gt_class,
|
||||
'gt_bbox': gt_bbox,
|
||||
'gt_poly': gt_poly,
|
||||
}
|
||||
if has_track_id:
|
||||
gt_rec.update({'gt_track_id': gt_track_id})
|
||||
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
coco_rec[k] = v
|
||||
|
||||
# TODO: remove load_semantic
|
||||
if self.load_semantic and 'semantic' in self.data_fields:
|
||||
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
|
||||
'train2017', im_fname[:-3] + 'png')
|
||||
coco_rec.update({'semantic': seg_path})
|
||||
|
||||
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
|
||||
im_path, img_id, im_h, im_w))
|
||||
if is_empty:
|
||||
empty_records.append(coco_rec)
|
||||
else:
|
||||
records.append(coco_rec)
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
|
||||
format(ct, len(img_ids) - ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class SlicedCOCODataSet(COCODataSet):
|
||||
"""Sliced COCODataSet"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1,
|
||||
sliced_size=[640, 640],
|
||||
overlap_ratio=[0.25, 0.25], ):
|
||||
super(SlicedCOCODataSet, self).__init__(
|
||||
dataset_dir=dataset_dir,
|
||||
image_dir=image_dir,
|
||||
anno_path=anno_path,
|
||||
data_fields=data_fields,
|
||||
sample_num=sample_num,
|
||||
load_crowd=load_crowd,
|
||||
allow_empty=allow_empty,
|
||||
empty_ratio=empty_ratio,
|
||||
repeat=repeat, )
|
||||
self.sliced_size = sliced_size
|
||||
self.overlap_ratio = overlap_ratio
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
ct_sub = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
try:
|
||||
import sahi
|
||||
from sahi.slicing import slice_image
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'sahi not found, plaese install sahi. '
|
||||
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
|
||||
)
|
||||
raise e
|
||||
|
||||
sub_img_ids = 0
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
slice_image_result = sahi.slicing.slice_image(
|
||||
image=im_path,
|
||||
slice_height=self.sliced_size[0],
|
||||
slice_width=self.sliced_size[1],
|
||||
overlap_height_ratio=self.overlap_ratio[0],
|
||||
overlap_width_ratio=self.overlap_ratio[1])
|
||||
|
||||
sub_img_num = len(slice_image_result)
|
||||
for _ind in range(sub_img_num):
|
||||
im = slice_image_result.images[_ind]
|
||||
coco_rec = {
|
||||
'image': im,
|
||||
'im_id': np.array([sub_img_ids + _ind]),
|
||||
'h': im.shape[0],
|
||||
'w': im.shape[1],
|
||||
'ori_im_id': np.array([img_id]),
|
||||
'st_pix': np.array(
|
||||
slice_image_result.starting_pixels[_ind],
|
||||
dtype=np.float32),
|
||||
'is_last': 1 if _ind == sub_img_num - 1 else 0,
|
||||
} if 'image' in self.data_fields else {}
|
||||
records.append(coco_rec)
|
||||
ct_sub += sub_img_num
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('{} samples and slice to {} sub_samples in file {}'.format(
|
||||
ct, ct_sub, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class SemiCOCODataSet(COCODataSet):
|
||||
"""Semi-COCODataSet used for supervised and unsupervised dataSet"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1,
|
||||
supervised=True):
|
||||
super(SemiCOCODataSet, self).__init__(
|
||||
dataset_dir, image_dir, anno_path, data_fields, sample_num,
|
||||
load_crowd, allow_empty, empty_ratio, repeat)
|
||||
self.supervised = supervised
|
||||
self.length = -1 # defalut -1 means all
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset or self.supervised == False:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
coco_rec = {
|
||||
'im_file': im_path,
|
||||
'im_id': np.array([img_id]),
|
||||
'h': im_h,
|
||||
'w': im_w,
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
if not self.load_image_only:
|
||||
ins_anno_ids = coco.getAnnIds(
|
||||
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
|
||||
instances = coco.loadAnns(ins_anno_ids)
|
||||
|
||||
bboxes = []
|
||||
is_rbox_anno = False
|
||||
for inst in instances:
|
||||
# check gt bbox
|
||||
if inst.get('ignore', False):
|
||||
continue
|
||||
if 'bbox' not in inst.keys():
|
||||
continue
|
||||
else:
|
||||
if not any(np.array(inst['bbox'])):
|
||||
continue
|
||||
|
||||
x1, y1, box_w, box_h = inst['bbox']
|
||||
x2 = x1 + box_w
|
||||
y2 = y1 + box_h
|
||||
eps = 1e-5
|
||||
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
|
||||
inst['clean_bbox'] = [
|
||||
round(float(x), 3) for x in [x1, y1, x2, y2]
|
||||
]
|
||||
bboxes.append(inst)
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: im_id: {}, '
|
||||
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
img_id, float(inst['area']), x1, y1, x2, y2))
|
||||
|
||||
num_bbox = len(bboxes)
|
||||
if num_bbox <= 0 and not self.allow_empty:
|
||||
continue
|
||||
elif num_bbox <= 0:
|
||||
is_empty = True
|
||||
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_poly = [None] * num_bbox
|
||||
|
||||
has_segmentation = False
|
||||
for i, box in enumerate(bboxes):
|
||||
catid = box['category_id']
|
||||
gt_class[i][0] = self.catid2clsid[catid]
|
||||
gt_bbox[i, :] = box['clean_bbox']
|
||||
is_crowd[i][0] = box['iscrowd']
|
||||
# check RLE format
|
||||
if 'segmentation' in box and box['iscrowd'] == 1:
|
||||
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
|
||||
elif 'segmentation' in box and box['segmentation']:
|
||||
if not np.array(box['segmentation']
|
||||
).size > 0 and not self.allow_empty:
|
||||
bboxes.pop(i)
|
||||
gt_poly.pop(i)
|
||||
np.delete(is_crowd, i)
|
||||
np.delete(gt_class, i)
|
||||
np.delete(gt_bbox, i)
|
||||
else:
|
||||
gt_poly[i] = box['segmentation']
|
||||
has_segmentation = True
|
||||
|
||||
if has_segmentation and not any(
|
||||
gt_poly) and not self.allow_empty:
|
||||
continue
|
||||
|
||||
gt_rec = {
|
||||
'is_crowd': is_crowd,
|
||||
'gt_class': gt_class,
|
||||
'gt_bbox': gt_bbox,
|
||||
'gt_poly': gt_poly,
|
||||
}
|
||||
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
coco_rec[k] = v
|
||||
|
||||
# TODO: remove load_semantic
|
||||
if self.load_semantic and 'semantic' in self.data_fields:
|
||||
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
|
||||
'train2017', im_fname[:-3] + 'png')
|
||||
coco_rec.update({'semantic': seg_path})
|
||||
|
||||
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
|
||||
im_path, img_id, im_h, im_w))
|
||||
if is_empty:
|
||||
empty_records.append(coco_rec)
|
||||
else:
|
||||
records.append(coco_rec)
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
|
||||
format(ct, len(img_ids) - ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
if self.supervised:
|
||||
logger.info(f'Use {len(self.roidbs)} sup_samples data as LABELED')
|
||||
else:
|
||||
if self.length > 0: # unsup length will be decide by sup length
|
||||
all_roidbs = self.roidbs.copy()
|
||||
selected_idxs = [
|
||||
np.random.choice(len(all_roidbs))
|
||||
for _ in range(self.length)
|
||||
]
|
||||
self.roidbs = [all_roidbs[i] for i in selected_idxs]
|
||||
logger.info(
|
||||
f'Use {len(self.roidbs)} unsup_samples data as UNLABELED')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
n = len(self.roidbs)
|
||||
if self.repeat > 1:
|
||||
idx %= n
|
||||
# data batch
|
||||
roidb = copy.deepcopy(self.roidbs[idx])
|
||||
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
|
||||
roidb = [roidb, ] + [
|
||||
copy.deepcopy(self.roidbs[np.random.randint(n)])
|
||||
for _ in range(4)
|
||||
]
|
||||
if isinstance(roidb, Sequence):
|
||||
for r in roidb:
|
||||
r['curr_iter'] = self._curr_iter
|
||||
else:
|
||||
roidb['curr_iter'] = self._curr_iter
|
||||
self._curr_iter += 1
|
||||
|
||||
return self.transform(roidb)
|
||||
307
rtdetr_paddle/ppdet/data/source/dataset.py
Normal file
307
rtdetr_paddle/ppdet/data/source/dataset.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
from paddle.io import Dataset
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from ppdet.utils.download import get_dataset_path
|
||||
from ppdet.data import source
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@serializable
|
||||
class DetDataset(Dataset):
|
||||
"""
|
||||
Load detection dataset.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
use_default_label (bool): whether to load default label list.
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
use_default_label=None,
|
||||
repeat=1,
|
||||
**kwargs):
|
||||
super(DetDataset, self).__init__()
|
||||
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
|
||||
self.anno_path = anno_path
|
||||
self.image_dir = image_dir if image_dir is not None else ''
|
||||
self.data_fields = data_fields
|
||||
self.sample_num = sample_num
|
||||
self.use_default_label = use_default_label
|
||||
self.repeat = repeat
|
||||
self._epoch = 0
|
||||
self._curr_iter = 0
|
||||
|
||||
def __len__(self, ):
|
||||
return len(self.roidbs) * self.repeat
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __getitem__(self, idx):
|
||||
n = len(self.roidbs)
|
||||
if self.repeat > 1:
|
||||
idx %= n
|
||||
# data batch
|
||||
roidb = copy.deepcopy(self.roidbs[idx])
|
||||
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
|
||||
roidb = [roidb, ] + [
|
||||
copy.deepcopy(self.roidbs[np.random.randint(n)])
|
||||
for _ in range(4)
|
||||
]
|
||||
elif self.pre_img_epoch == 0 or self._epoch < self.pre_img_epoch:
|
||||
# Add previous image as input, only used in CenterTrack
|
||||
idx_pre_img = idx - 1
|
||||
if idx_pre_img < 0:
|
||||
idx_pre_img = idx + 1
|
||||
roidb = [roidb, ] + [copy.deepcopy(self.roidbs[idx_pre_img])]
|
||||
if isinstance(roidb, Sequence):
|
||||
for r in roidb:
|
||||
r['curr_iter'] = self._curr_iter
|
||||
else:
|
||||
roidb['curr_iter'] = self._curr_iter
|
||||
self._curr_iter += 1
|
||||
|
||||
return self.transform(roidb)
|
||||
|
||||
def check_or_download_dataset(self):
|
||||
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
|
||||
self.image_dir)
|
||||
|
||||
def set_kwargs(self, **kwargs):
|
||||
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
|
||||
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
|
||||
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
|
||||
self.pre_img_epoch = kwargs.get('pre_img_epoch', -1)
|
||||
|
||||
def set_transform(self, transform):
|
||||
self.transform = transform
|
||||
|
||||
def set_epoch(self, epoch_id):
|
||||
self._epoch = epoch_id
|
||||
|
||||
def parse_dataset(self, ):
|
||||
raise NotImplementedError(
|
||||
"Need to implement parse_dataset method of Dataset")
|
||||
|
||||
def get_anno(self):
|
||||
if self.anno_path is None:
|
||||
return
|
||||
return os.path.join(self.dataset_dir, self.anno_path)
|
||||
|
||||
|
||||
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
|
||||
return f.lower().endswith(extensions)
|
||||
|
||||
|
||||
def _make_dataset(dir):
|
||||
dir = os.path.expanduser(dir)
|
||||
if not os.path.isdir(dir):
|
||||
raise ('{} should be a dir'.format(dir))
|
||||
images = []
|
||||
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if _is_valid_file(path):
|
||||
images.append(path)
|
||||
return images
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class ImageFolder(DetDataset):
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
sample_num=-1,
|
||||
use_default_label=None,
|
||||
**kwargs):
|
||||
super(ImageFolder, self).__init__(
|
||||
dataset_dir,
|
||||
image_dir,
|
||||
anno_path,
|
||||
sample_num=sample_num,
|
||||
use_default_label=use_default_label)
|
||||
self._imid2path = {}
|
||||
self.roidbs = None
|
||||
self.sample_num = sample_num
|
||||
|
||||
def check_or_download_dataset(self):
|
||||
return
|
||||
|
||||
def get_anno(self):
|
||||
if self.anno_path is None:
|
||||
return
|
||||
if self.dataset_dir:
|
||||
return os.path.join(self.dataset_dir, self.anno_path)
|
||||
else:
|
||||
return self.anno_path
|
||||
|
||||
def parse_dataset(self, ):
|
||||
if not self.roidbs:
|
||||
self.roidbs = self._load_images()
|
||||
|
||||
def _parse(self):
|
||||
image_dir = self.image_dir
|
||||
if not isinstance(image_dir, Sequence):
|
||||
image_dir = [image_dir]
|
||||
images = []
|
||||
for im_dir in image_dir:
|
||||
if os.path.isdir(im_dir):
|
||||
im_dir = os.path.join(self.dataset_dir, im_dir)
|
||||
images.extend(_make_dataset(im_dir))
|
||||
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
|
||||
images.append(im_dir)
|
||||
return images
|
||||
|
||||
def _load_images(self):
|
||||
images = self._parse()
|
||||
ct = 0
|
||||
records = []
|
||||
for image in images:
|
||||
assert image != '' and os.path.isfile(image), \
|
||||
"Image {} not found".format(image)
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
rec = {'im_id': np.array([ct]), 'im_file': image}
|
||||
self._imid2path[ct] = image
|
||||
ct += 1
|
||||
records.append(rec)
|
||||
assert len(records) > 0, "No image file found"
|
||||
return records
|
||||
|
||||
def get_imid2path(self):
|
||||
return self._imid2path
|
||||
|
||||
def set_images(self, images):
|
||||
self.image_dir = images
|
||||
self.roidbs = self._load_images()
|
||||
|
||||
def set_slice_images(self,
|
||||
images,
|
||||
slice_size=[640, 640],
|
||||
overlap_ratio=[0.25, 0.25]):
|
||||
self.image_dir = images
|
||||
ori_records = self._load_images()
|
||||
try:
|
||||
import sahi
|
||||
from sahi.slicing import slice_image
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'sahi not found, plaese install sahi. '
|
||||
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
|
||||
)
|
||||
raise e
|
||||
|
||||
sub_img_ids = 0
|
||||
ct = 0
|
||||
ct_sub = 0
|
||||
records = []
|
||||
for i, ori_rec in enumerate(ori_records):
|
||||
im_path = ori_rec['im_file']
|
||||
slice_image_result = sahi.slicing.slice_image(
|
||||
image=im_path,
|
||||
slice_height=slice_size[0],
|
||||
slice_width=slice_size[1],
|
||||
overlap_height_ratio=overlap_ratio[0],
|
||||
overlap_width_ratio=overlap_ratio[1])
|
||||
|
||||
sub_img_num = len(slice_image_result)
|
||||
for _ind in range(sub_img_num):
|
||||
im = slice_image_result.images[_ind]
|
||||
rec = {
|
||||
'image': im,
|
||||
'im_id': np.array([sub_img_ids + _ind]),
|
||||
'h': im.shape[0],
|
||||
'w': im.shape[1],
|
||||
'ori_im_id': np.array([ori_rec['im_id'][0]]),
|
||||
'st_pix': np.array(
|
||||
slice_image_result.starting_pixels[_ind],
|
||||
dtype=np.float32),
|
||||
'is_last': 1 if _ind == sub_img_num - 1 else 0,
|
||||
} if 'image' in self.data_fields else {}
|
||||
records.append(rec)
|
||||
ct_sub += sub_img_num
|
||||
ct += 1
|
||||
logger.info('{} samples and slice to {} sub_samples.'.format(ct,
|
||||
ct_sub))
|
||||
self.roidbs = records
|
||||
|
||||
def get_label_list(self):
|
||||
# Only VOC dataset needs label list in ImageFold
|
||||
return self.anno_path
|
||||
|
||||
|
||||
@register
|
||||
class CommonDataset(object):
|
||||
def __init__(self, **dataset_args):
|
||||
super(CommonDataset, self).__init__()
|
||||
dataset_args = copy.deepcopy(dataset_args)
|
||||
type = dataset_args.pop("name")
|
||||
self.dataset = getattr(source, type)(**dataset_args)
|
||||
|
||||
def __call__(self):
|
||||
return self.dataset
|
||||
|
||||
|
||||
@register
|
||||
class TrainDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class EvalMOTDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class TestMOTDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class EvalDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class TestDataset(CommonDataset):
|
||||
pass
|
||||
234
rtdetr_paddle/ppdet/data/source/voc.py
Normal file
234
rtdetr_paddle/ppdet/data/source/voc.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class VOCDataSet(DetDataset):
|
||||
"""
|
||||
Load dataset with PascalVOC format.
|
||||
|
||||
Notes:
|
||||
`anno_path` must contains xml file and image file path for annotations.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): voc annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
label_list (str): if use_default_label is False, will load
|
||||
mapping between category and class index.
|
||||
allow_empty (bool): whether to load empty entry. False as default
|
||||
empty_ratio (float): the ratio of empty record number to total
|
||||
record's, if empty_ratio is out of [0. ,1.), do not sample the
|
||||
records and use all the empty entries. 1. as default
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
label_list=None,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1):
|
||||
super(VOCDataSet, self).__init__(
|
||||
dataset_dir=dataset_dir,
|
||||
image_dir=image_dir,
|
||||
anno_path=anno_path,
|
||||
data_fields=data_fields,
|
||||
sample_num=sample_num,
|
||||
repeat=repeat)
|
||||
self.label_list = label_list
|
||||
self.allow_empty = allow_empty
|
||||
self.empty_ratio = empty_ratio
|
||||
|
||||
def _sample_empty(self, records, num):
|
||||
# if empty_ratio is out of [0. ,1.), do not sample the records
|
||||
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
|
||||
return records
|
||||
import random
|
||||
sample_num = min(
|
||||
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
|
||||
records = random.sample(records, sample_num)
|
||||
return records
|
||||
|
||||
def parse_dataset(self, ):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
# mapping category name to class id
|
||||
# first_class:0, second_class:1, ...
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
cname2cid = {}
|
||||
if self.label_list:
|
||||
label_path = os.path.join(self.dataset_dir, self.label_list)
|
||||
if not os.path.exists(label_path):
|
||||
raise ValueError("label_list {} does not exists".format(
|
||||
label_path))
|
||||
with open(label_path, 'r') as fr:
|
||||
label_id = 0
|
||||
for line in fr.readlines():
|
||||
cname2cid[line.strip()] = label_id
|
||||
label_id += 1
|
||||
else:
|
||||
cname2cid = pascalvoc_label()
|
||||
|
||||
with open(anno_path, 'r') as fr:
|
||||
while True:
|
||||
line = fr.readline()
|
||||
if not line:
|
||||
break
|
||||
img_file, xml_file = [os.path.join(image_dir, x) \
|
||||
for x in line.strip().split()[:2]]
|
||||
if not os.path.exists(img_file):
|
||||
logger.warning(
|
||||
'Illegal image file: {}, and it will be ignored'.format(
|
||||
img_file))
|
||||
continue
|
||||
if not os.path.isfile(xml_file):
|
||||
logger.warning(
|
||||
'Illegal xml file: {}, and it will be ignored'.format(
|
||||
xml_file))
|
||||
continue
|
||||
tree = ET.parse(xml_file)
|
||||
if tree.find('id') is None:
|
||||
im_id = np.array([ct])
|
||||
else:
|
||||
im_id = np.array([int(tree.find('id').text)])
|
||||
|
||||
objs = tree.findall('object')
|
||||
im_w = float(tree.find('size').find('width').text)
|
||||
im_h = float(tree.find('size').find('height').text)
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning(
|
||||
'Illegal width: {} or height: {} in annotation, '
|
||||
'and {} will be ignored'.format(im_w, im_h, xml_file))
|
||||
continue
|
||||
|
||||
num_bbox, i = len(objs), 0
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
|
||||
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
for obj in objs:
|
||||
cname = obj.find('name').text
|
||||
|
||||
# user dataset may not contain difficult field
|
||||
_difficult = obj.find('difficult')
|
||||
_difficult = int(
|
||||
_difficult.text) if _difficult is not None else 0
|
||||
|
||||
x1 = float(obj.find('bndbox').find('xmin').text)
|
||||
y1 = float(obj.find('bndbox').find('ymin').text)
|
||||
x2 = float(obj.find('bndbox').find('xmax').text)
|
||||
y2 = float(obj.find('bndbox').find('ymax').text)
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(im_w - 1, x2)
|
||||
y2 = min(im_h - 1, y2)
|
||||
if x2 > x1 and y2 > y1:
|
||||
gt_bbox[i, :] = [x1, y1, x2, y2]
|
||||
gt_class[i, 0] = cname2cid[cname]
|
||||
gt_score[i, 0] = 1.
|
||||
difficult[i, 0] = _difficult
|
||||
i += 1
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: xml_file: {}'
|
||||
', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
xml_file, x1, y1, x2, y2))
|
||||
gt_bbox = gt_bbox[:i, :]
|
||||
gt_class = gt_class[:i, :]
|
||||
gt_score = gt_score[:i, :]
|
||||
difficult = difficult[:i, :]
|
||||
|
||||
voc_rec = {
|
||||
'im_file': img_file,
|
||||
'im_id': im_id,
|
||||
'h': im_h,
|
||||
'w': im_w
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
gt_rec = {
|
||||
'gt_class': gt_class,
|
||||
'gt_score': gt_score,
|
||||
'gt_bbox': gt_bbox,
|
||||
'difficult': difficult
|
||||
}
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
voc_rec[k] = v
|
||||
|
||||
if len(objs) == 0:
|
||||
empty_records.append(voc_rec)
|
||||
else:
|
||||
records.append(voc_rec)
|
||||
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
|
||||
logger.debug('{} samples in file {}'.format(ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs, self.cname2cid = records, cname2cid
|
||||
|
||||
def get_label_list(self):
|
||||
return os.path.join(self.dataset_dir, self.label_list)
|
||||
|
||||
|
||||
def pascalvoc_label():
|
||||
labels_map = {
|
||||
'aeroplane': 0,
|
||||
'bicycle': 1,
|
||||
'bird': 2,
|
||||
'boat': 3,
|
||||
'bottle': 4,
|
||||
'bus': 5,
|
||||
'car': 6,
|
||||
'cat': 7,
|
||||
'chair': 8,
|
||||
'cow': 9,
|
||||
'diningtable': 10,
|
||||
'dog': 11,
|
||||
'horse': 12,
|
||||
'motorbike': 13,
|
||||
'person': 14,
|
||||
'pottedplant': 15,
|
||||
'sheep': 16,
|
||||
'sofa': 17,
|
||||
'train': 18,
|
||||
'tvmonitor': 19
|
||||
}
|
||||
return labels_map
|
||||
Reference in New Issue
Block a user