first commit
This commit is contained in:
21
rtdetr_paddle/ppdet/data/__init__.py
Normal file
21
rtdetr_paddle/ppdet/data/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import source
|
||||
from . import transform
|
||||
from . import reader
|
||||
|
||||
from .source import *
|
||||
from .transform import *
|
||||
from .reader import *
|
||||
274
rtdetr_paddle/ppdet/data/reader.py
Normal file
274
rtdetr_paddle/ppdet/data/reader.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import traceback
|
||||
import six
|
||||
import sys
|
||||
if sys.version_info >= (3, 0):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from .utils import default_collate_fn
|
||||
|
||||
from ppdet.core.workspace import register
|
||||
from . import transform
|
||||
from .shm_utils import _get_shared_memory_size_in_M
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger('reader')
|
||||
|
||||
MAIN_PID = os.getpid()
|
||||
|
||||
|
||||
class Compose(object):
|
||||
def __init__(self, transforms, num_classes=80):
|
||||
self.transforms = transforms
|
||||
self.transforms_cls = []
|
||||
for t in self.transforms:
|
||||
for k, v in t.items():
|
||||
op_cls = getattr(transform, k)
|
||||
f = op_cls(**v)
|
||||
if hasattr(f, 'num_classes'):
|
||||
f.num_classes = num_classes
|
||||
|
||||
self.transforms_cls.append(f)
|
||||
|
||||
def __call__(self, data):
|
||||
for f in self.transforms_cls:
|
||||
try:
|
||||
data = f(data)
|
||||
except Exception as e:
|
||||
stack_info = traceback.format_exc()
|
||||
logger.warning("fail to map sample transform [{}] "
|
||||
"with error: {} and stack:\n{}".format(
|
||||
f, e, str(stack_info)))
|
||||
raise e
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class BatchCompose(Compose):
|
||||
def __init__(self, transforms, num_classes=80, collate_batch=True):
|
||||
super(BatchCompose, self).__init__(transforms, num_classes)
|
||||
self.collate_batch = collate_batch
|
||||
|
||||
def __call__(self, data):
|
||||
for f in self.transforms_cls:
|
||||
try:
|
||||
data = f(data)
|
||||
except Exception as e:
|
||||
stack_info = traceback.format_exc()
|
||||
logger.warning("fail to map batch transform [{}] "
|
||||
"with error: {} and stack:\n{}".format(
|
||||
f, e, str(stack_info)))
|
||||
raise e
|
||||
|
||||
# remove keys which is not needed by model
|
||||
extra_key = ['h', 'w', 'flipped']
|
||||
for k in extra_key:
|
||||
for sample in data:
|
||||
if k in sample:
|
||||
sample.pop(k)
|
||||
|
||||
# batch data, if user-define batch function needed
|
||||
# use user-defined here
|
||||
if self.collate_batch:
|
||||
batch_data = default_collate_fn(data)
|
||||
else:
|
||||
batch_data = {}
|
||||
for k in data[0].keys():
|
||||
tmp_data = []
|
||||
for i in range(len(data)):
|
||||
tmp_data.append(data[i][k])
|
||||
if not 'gt_' in k and not 'is_crowd' in k and not 'difficult' in k:
|
||||
tmp_data = np.stack(tmp_data, axis=0)
|
||||
batch_data[k] = tmp_data
|
||||
return batch_data
|
||||
|
||||
|
||||
class BaseDataLoader(object):
|
||||
"""
|
||||
Base DataLoader implementation for detection models
|
||||
|
||||
Args:
|
||||
sample_transforms (list): a list of transforms to perform
|
||||
on each sample
|
||||
batch_transforms (list): a list of transforms to perform
|
||||
on batch
|
||||
batch_size (int): batch size for batch collating, default 1.
|
||||
shuffle (bool): whether to shuffle samples
|
||||
drop_last (bool): whether to drop the last incomplete,
|
||||
default False
|
||||
num_classes (int): class number of dataset, default 80
|
||||
collate_batch (bool): whether to collate batch in dataloader.
|
||||
If set to True, the samples will collate into batch according
|
||||
to the batch size. Otherwise, the ground-truth will not collate,
|
||||
which is used when the number of ground-truch is different in
|
||||
samples.
|
||||
use_shared_memory (bool): whether to use shared memory to
|
||||
accelerate data loading, enable this only if you
|
||||
are sure that the shared memory size of your OS
|
||||
is larger than memory cost of input datas of model.
|
||||
Note that shared memory will be automatically
|
||||
disabled if the shared memory of OS is less than
|
||||
1G, which is not enough for detection models.
|
||||
Default False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sample_transforms=[],
|
||||
batch_transforms=[],
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_classes=80,
|
||||
collate_batch=True,
|
||||
use_shared_memory=False,
|
||||
**kwargs):
|
||||
# sample transform
|
||||
self._sample_transforms = Compose(
|
||||
sample_transforms, num_classes=num_classes)
|
||||
|
||||
# batch transfrom
|
||||
self._batch_transforms = BatchCompose(batch_transforms, num_classes,
|
||||
collate_batch)
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.use_shared_memory = use_shared_memory
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self,
|
||||
dataset,
|
||||
worker_num,
|
||||
batch_sampler=None,
|
||||
return_list=False):
|
||||
self.dataset = dataset
|
||||
self.dataset.check_or_download_dataset()
|
||||
self.dataset.parse_dataset()
|
||||
# get data
|
||||
self.dataset.set_transform(self._sample_transforms)
|
||||
# set kwargs
|
||||
self.dataset.set_kwargs(**self.kwargs)
|
||||
# batch sampler
|
||||
if batch_sampler is None:
|
||||
self._batch_sampler = DistributedBatchSampler(
|
||||
self.dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=self.shuffle,
|
||||
drop_last=self.drop_last)
|
||||
else:
|
||||
self._batch_sampler = batch_sampler
|
||||
|
||||
# DataLoader do not start sub-process in Windows and Mac
|
||||
# system, do not need to use shared memory
|
||||
use_shared_memory = self.use_shared_memory and \
|
||||
sys.platform not in ['win32', 'darwin']
|
||||
# check whether shared memory size is bigger than 1G(1024M)
|
||||
if use_shared_memory:
|
||||
shm_size = _get_shared_memory_size_in_M()
|
||||
if shm_size is not None and shm_size < 1024.:
|
||||
logger.warning("Shared memory size is less than 1G, "
|
||||
"disable shared_memory in DataLoader")
|
||||
use_shared_memory = False
|
||||
|
||||
self.dataloader = DataLoader(
|
||||
dataset=self.dataset,
|
||||
batch_sampler=self._batch_sampler,
|
||||
collate_fn=self._batch_transforms,
|
||||
num_workers=worker_num,
|
||||
return_list=return_list,
|
||||
use_shared_memory=use_shared_memory)
|
||||
self.loader = iter(self.dataloader)
|
||||
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return len(self._batch_sampler)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return next(self.loader)
|
||||
except StopIteration:
|
||||
self.loader = iter(self.dataloader)
|
||||
six.reraise(*sys.exc_info())
|
||||
|
||||
def next(self):
|
||||
# python2 compatibility
|
||||
return self.__next__()
|
||||
|
||||
|
||||
@register
|
||||
class TrainReader(BaseDataLoader):
|
||||
__shared__ = ['num_classes']
|
||||
|
||||
def __init__(self,
|
||||
sample_transforms=[],
|
||||
batch_transforms=[],
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
num_classes=80,
|
||||
collate_batch=True,
|
||||
**kwargs):
|
||||
super(TrainReader, self).__init__(sample_transforms, batch_transforms,
|
||||
batch_size, shuffle, drop_last,
|
||||
num_classes, collate_batch, **kwargs)
|
||||
|
||||
|
||||
@register
|
||||
class EvalReader(BaseDataLoader):
|
||||
__shared__ = ['num_classes']
|
||||
|
||||
def __init__(self,
|
||||
sample_transforms=[],
|
||||
batch_transforms=[],
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_classes=80,
|
||||
**kwargs):
|
||||
super(EvalReader, self).__init__(sample_transforms, batch_transforms,
|
||||
batch_size, shuffle, drop_last,
|
||||
num_classes, **kwargs)
|
||||
|
||||
|
||||
@register
|
||||
class TestReader(BaseDataLoader):
|
||||
__shared__ = ['num_classes']
|
||||
|
||||
def __init__(self,
|
||||
sample_transforms=[],
|
||||
batch_transforms=[],
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
num_classes=80,
|
||||
**kwargs):
|
||||
super(TestReader, self).__init__(sample_transforms, batch_transforms,
|
||||
batch_size, shuffle, drop_last,
|
||||
num_classes, **kwargs)
|
||||
|
||||
70
rtdetr_paddle/ppdet/data/shm_utils.py
Normal file
70
rtdetr_paddle/ppdet/data/shm_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
SIZE_UNIT = ['K', 'M', 'G', 'T']
|
||||
SHM_QUERY_CMD = 'df -h'
|
||||
SHM_KEY = 'shm'
|
||||
SHM_DEFAULT_MOUNT = '/dev/shm'
|
||||
|
||||
# [ shared memory size check ]
|
||||
# In detection models, image/target data occupies a lot of memory, and
|
||||
# will occupy lots of shared memory in multi-process DataLoader, we use
|
||||
# following code to get shared memory size and perform a size check to
|
||||
# disable shared memory use if shared memory size is not enough.
|
||||
# Shared memory getting process as follows:
|
||||
# 1. use `df -h` get all mount info
|
||||
# 2. pick up spaces whose mount info contains 'shm'
|
||||
# 3. if 'shm' space number is only 1, return its size
|
||||
# 4. if there are multiple 'shm' space, try to find the default mount
|
||||
# directory '/dev/shm' is Linux-like system, otherwise return the
|
||||
# biggest space size.
|
||||
|
||||
|
||||
def _parse_size_in_M(size_str):
|
||||
if size_str[-1] == 'B':
|
||||
num, unit = size_str[:-2], size_str[-2]
|
||||
else:
|
||||
num, unit = size_str[:-1], size_str[-1]
|
||||
assert unit in SIZE_UNIT, \
|
||||
"unknown shm size unit {}".format(unit)
|
||||
return float(num) * \
|
||||
(1024 ** (SIZE_UNIT.index(unit) - 1))
|
||||
|
||||
|
||||
def _get_shared_memory_size_in_M():
|
||||
try:
|
||||
df_infos = os.popen(SHM_QUERY_CMD).readlines()
|
||||
except:
|
||||
return None
|
||||
else:
|
||||
shm_infos = []
|
||||
for df_info in df_infos:
|
||||
info = df_info.strip()
|
||||
if info.find(SHM_KEY) >= 0:
|
||||
shm_infos.append(info.split())
|
||||
|
||||
if len(shm_infos) == 0:
|
||||
return None
|
||||
elif len(shm_infos) == 1:
|
||||
return _parse_size_in_M(shm_infos[0][3])
|
||||
else:
|
||||
default_mount_infos = [
|
||||
si for si in shm_infos if si[-1] == SHM_DEFAULT_MOUNT
|
||||
]
|
||||
if default_mount_infos:
|
||||
return _parse_size_in_M(default_mount_infos[0][3])
|
||||
else:
|
||||
return max([_parse_size_in_M(si[3]) for si in shm_infos])
|
||||
18
rtdetr_paddle/ppdet/data/source/__init__.py
Normal file
18
rtdetr_paddle/ppdet/data/source/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .coco import *
|
||||
from .voc import *
|
||||
from .category import *
|
||||
from .dataset import ImageFolder
|
||||
926
rtdetr_paddle/ppdet/data/source/category.py
Normal file
926
rtdetr_paddle/ppdet/data/source/category.py
Normal file
@@ -0,0 +1,926 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ppdet.data.source.voc import pascalvoc_label
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['get_categories']
|
||||
|
||||
|
||||
def get_categories(metric_type, anno_file=None, arch=None):
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map from annotation file.
|
||||
|
||||
Args:
|
||||
metric_type (str): metric type, currently support 'coco', 'voc', 'oid'
|
||||
and 'widerface'.
|
||||
anno_file (str): annotation file path
|
||||
"""
|
||||
if arch == 'keypoint_arch':
|
||||
return (None, {'id': 'keypoint'})
|
||||
|
||||
if anno_file == None or (not os.path.isfile(anno_file)):
|
||||
logger.warning(
|
||||
"anno_file '{}' is None or not set or not exist, "
|
||||
"please recheck TrainDataset/EvalDataset/TestDataset.anno_path, "
|
||||
"otherwise the default categories will be used by metric_type.".
|
||||
format(anno_file))
|
||||
|
||||
if metric_type.lower() == 'coco' or metric_type.lower(
|
||||
) == 'rbox' or metric_type.lower() == 'snipercoco':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
if anno_file.endswith('json'):
|
||||
# lazy import pycocotools here
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_file)
|
||||
cats = coco.loadCats(coco.getCatIds())
|
||||
|
||||
clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
|
||||
catid2name = {cat['id']: cat['name'] for cat in cats}
|
||||
|
||||
elif anno_file.endswith('txt'):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background': cats = cats[1:]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
else:
|
||||
raise ValueError("anno_file {} should be json or txt.".format(
|
||||
anno_file))
|
||||
return clsid2catid, catid2name
|
||||
|
||||
# anno file not exist, load default categories of COCO17
|
||||
else:
|
||||
if metric_type.lower() == 'rbox':
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of DOTA.".format(
|
||||
metric_type))
|
||||
return _dota_category()
|
||||
logger.warning("metric_type: {}, load default categories of COCO.".
|
||||
format(metric_type))
|
||||
return _coco17_category()
|
||||
|
||||
elif metric_type.lower() == 'voc':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
# anno file not exist, load default categories of
|
||||
# VOC all 20 categories
|
||||
else:
|
||||
logger.warning("metric_type: {}, load default categories of VOC.".
|
||||
format(metric_type))
|
||||
return _vocall_category()
|
||||
|
||||
elif metric_type.lower() == 'oid':
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
logger.warning("only default categories support for OID19")
|
||||
return _oid19_category()
|
||||
|
||||
elif metric_type.lower() == 'keypointtopdowncocoeval' or metric_type.lower(
|
||||
) == 'keypointtopdownmpiieval':
|
||||
return (None, {'id': 'keypoint'})
|
||||
|
||||
elif metric_type.lower() == 'pose3deval':
|
||||
return (None, {'id': 'pose3d'})
|
||||
|
||||
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
return clsid2catid, catid2name
|
||||
# anno file not exist, load default category 'pedestrian'.
|
||||
else:
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of pedestrian MOT.".
|
||||
format(metric_type))
|
||||
return _mot_category(category='pedestrian')
|
||||
|
||||
elif metric_type.lower() in ['kitti', 'bdd100kmot']:
|
||||
return _mot_category(category='vehicle')
|
||||
|
||||
elif metric_type.lower() in ['mcmot']:
|
||||
if anno_file and os.path.isfile(anno_file):
|
||||
cats = []
|
||||
with open(anno_file) as f:
|
||||
for line in f.readlines():
|
||||
cats.append(line.strip())
|
||||
if cats[0] == 'background':
|
||||
cats = cats[1:]
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
return clsid2catid, catid2name
|
||||
# anno file not exist, load default categories of visdrone all 10 categories
|
||||
else:
|
||||
logger.warning(
|
||||
"metric_type: {}, load default categories of VisDrone.".format(
|
||||
metric_type))
|
||||
return _visdrone_category()
|
||||
|
||||
else:
|
||||
raise ValueError("unknown metric type {}".format(metric_type))
|
||||
|
||||
|
||||
def _mot_category(category='pedestrian'):
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of mot dataset
|
||||
"""
|
||||
label_map = {category: 0}
|
||||
label_map = sorted(label_map.items(), key=lambda x: x[1])
|
||||
cats = [l[0] for l in label_map]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _coco17_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of COCO2017 dataset
|
||||
|
||||
"""
|
||||
clsid2catid = {
|
||||
1: 1,
|
||||
2: 2,
|
||||
3: 3,
|
||||
4: 4,
|
||||
5: 5,
|
||||
6: 6,
|
||||
7: 7,
|
||||
8: 8,
|
||||
9: 9,
|
||||
10: 10,
|
||||
11: 11,
|
||||
12: 13,
|
||||
13: 14,
|
||||
14: 15,
|
||||
15: 16,
|
||||
16: 17,
|
||||
17: 18,
|
||||
18: 19,
|
||||
19: 20,
|
||||
20: 21,
|
||||
21: 22,
|
||||
22: 23,
|
||||
23: 24,
|
||||
24: 25,
|
||||
25: 27,
|
||||
26: 28,
|
||||
27: 31,
|
||||
28: 32,
|
||||
29: 33,
|
||||
30: 34,
|
||||
31: 35,
|
||||
32: 36,
|
||||
33: 37,
|
||||
34: 38,
|
||||
35: 39,
|
||||
36: 40,
|
||||
37: 41,
|
||||
38: 42,
|
||||
39: 43,
|
||||
40: 44,
|
||||
41: 46,
|
||||
42: 47,
|
||||
43: 48,
|
||||
44: 49,
|
||||
45: 50,
|
||||
46: 51,
|
||||
47: 52,
|
||||
48: 53,
|
||||
49: 54,
|
||||
50: 55,
|
||||
51: 56,
|
||||
52: 57,
|
||||
53: 58,
|
||||
54: 59,
|
||||
55: 60,
|
||||
56: 61,
|
||||
57: 62,
|
||||
58: 63,
|
||||
59: 64,
|
||||
60: 65,
|
||||
61: 67,
|
||||
62: 70,
|
||||
63: 72,
|
||||
64: 73,
|
||||
65: 74,
|
||||
66: 75,
|
||||
67: 76,
|
||||
68: 77,
|
||||
69: 78,
|
||||
70: 79,
|
||||
71: 80,
|
||||
72: 81,
|
||||
73: 82,
|
||||
74: 84,
|
||||
75: 85,
|
||||
76: 86,
|
||||
77: 87,
|
||||
78: 88,
|
||||
79: 89,
|
||||
80: 90
|
||||
}
|
||||
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'person',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'motorcycle',
|
||||
5: 'airplane',
|
||||
6: 'bus',
|
||||
7: 'train',
|
||||
8: 'truck',
|
||||
9: 'boat',
|
||||
10: 'traffic light',
|
||||
11: 'fire hydrant',
|
||||
13: 'stop sign',
|
||||
14: 'parking meter',
|
||||
15: 'bench',
|
||||
16: 'bird',
|
||||
17: 'cat',
|
||||
18: 'dog',
|
||||
19: 'horse',
|
||||
20: 'sheep',
|
||||
21: 'cow',
|
||||
22: 'elephant',
|
||||
23: 'bear',
|
||||
24: 'zebra',
|
||||
25: 'giraffe',
|
||||
27: 'backpack',
|
||||
28: 'umbrella',
|
||||
31: 'handbag',
|
||||
32: 'tie',
|
||||
33: 'suitcase',
|
||||
34: 'frisbee',
|
||||
35: 'skis',
|
||||
36: 'snowboard',
|
||||
37: 'sports ball',
|
||||
38: 'kite',
|
||||
39: 'baseball bat',
|
||||
40: 'baseball glove',
|
||||
41: 'skateboard',
|
||||
42: 'surfboard',
|
||||
43: 'tennis racket',
|
||||
44: 'bottle',
|
||||
46: 'wine glass',
|
||||
47: 'cup',
|
||||
48: 'fork',
|
||||
49: 'knife',
|
||||
50: 'spoon',
|
||||
51: 'bowl',
|
||||
52: 'banana',
|
||||
53: 'apple',
|
||||
54: 'sandwich',
|
||||
55: 'orange',
|
||||
56: 'broccoli',
|
||||
57: 'carrot',
|
||||
58: 'hot dog',
|
||||
59: 'pizza',
|
||||
60: 'donut',
|
||||
61: 'cake',
|
||||
62: 'chair',
|
||||
63: 'couch',
|
||||
64: 'potted plant',
|
||||
65: 'bed',
|
||||
67: 'dining table',
|
||||
70: 'toilet',
|
||||
72: 'tv',
|
||||
73: 'laptop',
|
||||
74: 'mouse',
|
||||
75: 'remote',
|
||||
76: 'keyboard',
|
||||
77: 'cell phone',
|
||||
78: 'microwave',
|
||||
79: 'oven',
|
||||
80: 'toaster',
|
||||
81: 'sink',
|
||||
82: 'refrigerator',
|
||||
84: 'book',
|
||||
85: 'clock',
|
||||
86: 'vase',
|
||||
87: 'scissors',
|
||||
88: 'teddy bear',
|
||||
89: 'hair drier',
|
||||
90: 'toothbrush'
|
||||
}
|
||||
|
||||
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
|
||||
catid2name.pop(0)
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _dota_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of dota dataset
|
||||
"""
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'plane',
|
||||
2: 'baseball-diamond',
|
||||
3: 'bridge',
|
||||
4: 'ground-track-field',
|
||||
5: 'small-vehicle',
|
||||
6: 'large-vehicle',
|
||||
7: 'ship',
|
||||
8: 'tennis-court',
|
||||
9: 'basketball-court',
|
||||
10: 'storage-tank',
|
||||
11: 'soccer-ball-field',
|
||||
12: 'roundabout',
|
||||
13: 'harbor',
|
||||
14: 'swimming-pool',
|
||||
15: 'helicopter'
|
||||
}
|
||||
catid2name.pop(0)
|
||||
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _vocall_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of mixup voc dataset
|
||||
|
||||
"""
|
||||
label_map = pascalvoc_label()
|
||||
label_map = sorted(label_map.items(), key=lambda x: x[1])
|
||||
cats = [l[0] for l in label_map]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _oid19_category():
|
||||
clsid2catid = {k: k + 1 for k in range(500)}
|
||||
|
||||
catid2name = {
|
||||
0: "background",
|
||||
1: "Infant bed",
|
||||
2: "Rose",
|
||||
3: "Flag",
|
||||
4: "Flashlight",
|
||||
5: "Sea turtle",
|
||||
6: "Camera",
|
||||
7: "Animal",
|
||||
8: "Glove",
|
||||
9: "Crocodile",
|
||||
10: "Cattle",
|
||||
11: "House",
|
||||
12: "Guacamole",
|
||||
13: "Penguin",
|
||||
14: "Vehicle registration plate",
|
||||
15: "Bench",
|
||||
16: "Ladybug",
|
||||
17: "Human nose",
|
||||
18: "Watermelon",
|
||||
19: "Flute",
|
||||
20: "Butterfly",
|
||||
21: "Washing machine",
|
||||
22: "Raccoon",
|
||||
23: "Segway",
|
||||
24: "Taco",
|
||||
25: "Jellyfish",
|
||||
26: "Cake",
|
||||
27: "Pen",
|
||||
28: "Cannon",
|
||||
29: "Bread",
|
||||
30: "Tree",
|
||||
31: "Shellfish",
|
||||
32: "Bed",
|
||||
33: "Hamster",
|
||||
34: "Hat",
|
||||
35: "Toaster",
|
||||
36: "Sombrero",
|
||||
37: "Tiara",
|
||||
38: "Bowl",
|
||||
39: "Dragonfly",
|
||||
40: "Moths and butterflies",
|
||||
41: "Antelope",
|
||||
42: "Vegetable",
|
||||
43: "Torch",
|
||||
44: "Building",
|
||||
45: "Power plugs and sockets",
|
||||
46: "Blender",
|
||||
47: "Billiard table",
|
||||
48: "Cutting board",
|
||||
49: "Bronze sculpture",
|
||||
50: "Turtle",
|
||||
51: "Broccoli",
|
||||
52: "Tiger",
|
||||
53: "Mirror",
|
||||
54: "Bear",
|
||||
55: "Zucchini",
|
||||
56: "Dress",
|
||||
57: "Volleyball",
|
||||
58: "Guitar",
|
||||
59: "Reptile",
|
||||
60: "Golf cart",
|
||||
61: "Tart",
|
||||
62: "Fedora",
|
||||
63: "Carnivore",
|
||||
64: "Car",
|
||||
65: "Lighthouse",
|
||||
66: "Coffeemaker",
|
||||
67: "Food processor",
|
||||
68: "Truck",
|
||||
69: "Bookcase",
|
||||
70: "Surfboard",
|
||||
71: "Footwear",
|
||||
72: "Bench",
|
||||
73: "Necklace",
|
||||
74: "Flower",
|
||||
75: "Radish",
|
||||
76: "Marine mammal",
|
||||
77: "Frying pan",
|
||||
78: "Tap",
|
||||
79: "Peach",
|
||||
80: "Knife",
|
||||
81: "Handbag",
|
||||
82: "Laptop",
|
||||
83: "Tent",
|
||||
84: "Ambulance",
|
||||
85: "Christmas tree",
|
||||
86: "Eagle",
|
||||
87: "Limousine",
|
||||
88: "Kitchen & dining room table",
|
||||
89: "Polar bear",
|
||||
90: "Tower",
|
||||
91: "Football",
|
||||
92: "Willow",
|
||||
93: "Human head",
|
||||
94: "Stop sign",
|
||||
95: "Banana",
|
||||
96: "Mixer",
|
||||
97: "Binoculars",
|
||||
98: "Dessert",
|
||||
99: "Bee",
|
||||
100: "Chair",
|
||||
101: "Wood-burning stove",
|
||||
102: "Flowerpot",
|
||||
103: "Beaker",
|
||||
104: "Oyster",
|
||||
105: "Woodpecker",
|
||||
106: "Harp",
|
||||
107: "Bathtub",
|
||||
108: "Wall clock",
|
||||
109: "Sports uniform",
|
||||
110: "Rhinoceros",
|
||||
111: "Beehive",
|
||||
112: "Cupboard",
|
||||
113: "Chicken",
|
||||
114: "Man",
|
||||
115: "Blue jay",
|
||||
116: "Cucumber",
|
||||
117: "Balloon",
|
||||
118: "Kite",
|
||||
119: "Fireplace",
|
||||
120: "Lantern",
|
||||
121: "Missile",
|
||||
122: "Book",
|
||||
123: "Spoon",
|
||||
124: "Grapefruit",
|
||||
125: "Squirrel",
|
||||
126: "Orange",
|
||||
127: "Coat",
|
||||
128: "Punching bag",
|
||||
129: "Zebra",
|
||||
130: "Billboard",
|
||||
131: "Bicycle",
|
||||
132: "Door handle",
|
||||
133: "Mechanical fan",
|
||||
134: "Ring binder",
|
||||
135: "Table",
|
||||
136: "Parrot",
|
||||
137: "Sock",
|
||||
138: "Vase",
|
||||
139: "Weapon",
|
||||
140: "Shotgun",
|
||||
141: "Glasses",
|
||||
142: "Seahorse",
|
||||
143: "Belt",
|
||||
144: "Watercraft",
|
||||
145: "Window",
|
||||
146: "Giraffe",
|
||||
147: "Lion",
|
||||
148: "Tire",
|
||||
149: "Vehicle",
|
||||
150: "Canoe",
|
||||
151: "Tie",
|
||||
152: "Shelf",
|
||||
153: "Picture frame",
|
||||
154: "Printer",
|
||||
155: "Human leg",
|
||||
156: "Boat",
|
||||
157: "Slow cooker",
|
||||
158: "Croissant",
|
||||
159: "Candle",
|
||||
160: "Pancake",
|
||||
161: "Pillow",
|
||||
162: "Coin",
|
||||
163: "Stretcher",
|
||||
164: "Sandal",
|
||||
165: "Woman",
|
||||
166: "Stairs",
|
||||
167: "Harpsichord",
|
||||
168: "Stool",
|
||||
169: "Bus",
|
||||
170: "Suitcase",
|
||||
171: "Human mouth",
|
||||
172: "Juice",
|
||||
173: "Skull",
|
||||
174: "Door",
|
||||
175: "Violin",
|
||||
176: "Chopsticks",
|
||||
177: "Digital clock",
|
||||
178: "Sunflower",
|
||||
179: "Leopard",
|
||||
180: "Bell pepper",
|
||||
181: "Harbor seal",
|
||||
182: "Snake",
|
||||
183: "Sewing machine",
|
||||
184: "Goose",
|
||||
185: "Helicopter",
|
||||
186: "Seat belt",
|
||||
187: "Coffee cup",
|
||||
188: "Microwave oven",
|
||||
189: "Hot dog",
|
||||
190: "Countertop",
|
||||
191: "Serving tray",
|
||||
192: "Dog bed",
|
||||
193: "Beer",
|
||||
194: "Sunglasses",
|
||||
195: "Golf ball",
|
||||
196: "Waffle",
|
||||
197: "Palm tree",
|
||||
198: "Trumpet",
|
||||
199: "Ruler",
|
||||
200: "Helmet",
|
||||
201: "Ladder",
|
||||
202: "Office building",
|
||||
203: "Tablet computer",
|
||||
204: "Toilet paper",
|
||||
205: "Pomegranate",
|
||||
206: "Skirt",
|
||||
207: "Gas stove",
|
||||
208: "Cookie",
|
||||
209: "Cart",
|
||||
210: "Raven",
|
||||
211: "Egg",
|
||||
212: "Burrito",
|
||||
213: "Goat",
|
||||
214: "Kitchen knife",
|
||||
215: "Skateboard",
|
||||
216: "Salt and pepper shakers",
|
||||
217: "Lynx",
|
||||
218: "Boot",
|
||||
219: "Platter",
|
||||
220: "Ski",
|
||||
221: "Swimwear",
|
||||
222: "Swimming pool",
|
||||
223: "Drinking straw",
|
||||
224: "Wrench",
|
||||
225: "Drum",
|
||||
226: "Ant",
|
||||
227: "Human ear",
|
||||
228: "Headphones",
|
||||
229: "Fountain",
|
||||
230: "Bird",
|
||||
231: "Jeans",
|
||||
232: "Television",
|
||||
233: "Crab",
|
||||
234: "Microphone",
|
||||
235: "Home appliance",
|
||||
236: "Snowplow",
|
||||
237: "Beetle",
|
||||
238: "Artichoke",
|
||||
239: "Jet ski",
|
||||
240: "Stationary bicycle",
|
||||
241: "Human hair",
|
||||
242: "Brown bear",
|
||||
243: "Starfish",
|
||||
244: "Fork",
|
||||
245: "Lobster",
|
||||
246: "Corded phone",
|
||||
247: "Drink",
|
||||
248: "Saucer",
|
||||
249: "Carrot",
|
||||
250: "Insect",
|
||||
251: "Clock",
|
||||
252: "Castle",
|
||||
253: "Tennis racket",
|
||||
254: "Ceiling fan",
|
||||
255: "Asparagus",
|
||||
256: "Jaguar",
|
||||
257: "Musical instrument",
|
||||
258: "Train",
|
||||
259: "Cat",
|
||||
260: "Rifle",
|
||||
261: "Dumbbell",
|
||||
262: "Mobile phone",
|
||||
263: "Taxi",
|
||||
264: "Shower",
|
||||
265: "Pitcher",
|
||||
266: "Lemon",
|
||||
267: "Invertebrate",
|
||||
268: "Turkey",
|
||||
269: "High heels",
|
||||
270: "Bust",
|
||||
271: "Elephant",
|
||||
272: "Scarf",
|
||||
273: "Barrel",
|
||||
274: "Trombone",
|
||||
275: "Pumpkin",
|
||||
276: "Box",
|
||||
277: "Tomato",
|
||||
278: "Frog",
|
||||
279: "Bidet",
|
||||
280: "Human face",
|
||||
281: "Houseplant",
|
||||
282: "Van",
|
||||
283: "Shark",
|
||||
284: "Ice cream",
|
||||
285: "Swim cap",
|
||||
286: "Falcon",
|
||||
287: "Ostrich",
|
||||
288: "Handgun",
|
||||
289: "Whiteboard",
|
||||
290: "Lizard",
|
||||
291: "Pasta",
|
||||
292: "Snowmobile",
|
||||
293: "Light bulb",
|
||||
294: "Window blind",
|
||||
295: "Muffin",
|
||||
296: "Pretzel",
|
||||
297: "Computer monitor",
|
||||
298: "Horn",
|
||||
299: "Furniture",
|
||||
300: "Sandwich",
|
||||
301: "Fox",
|
||||
302: "Convenience store",
|
||||
303: "Fish",
|
||||
304: "Fruit",
|
||||
305: "Earrings",
|
||||
306: "Curtain",
|
||||
307: "Grape",
|
||||
308: "Sofa bed",
|
||||
309: "Horse",
|
||||
310: "Luggage and bags",
|
||||
311: "Desk",
|
||||
312: "Crutch",
|
||||
313: "Bicycle helmet",
|
||||
314: "Tick",
|
||||
315: "Airplane",
|
||||
316: "Canary",
|
||||
317: "Spatula",
|
||||
318: "Watch",
|
||||
319: "Lily",
|
||||
320: "Kitchen appliance",
|
||||
321: "Filing cabinet",
|
||||
322: "Aircraft",
|
||||
323: "Cake stand",
|
||||
324: "Candy",
|
||||
325: "Sink",
|
||||
326: "Mouse",
|
||||
327: "Wine",
|
||||
328: "Wheelchair",
|
||||
329: "Goldfish",
|
||||
330: "Refrigerator",
|
||||
331: "French fries",
|
||||
332: "Drawer",
|
||||
333: "Treadmill",
|
||||
334: "Picnic basket",
|
||||
335: "Dice",
|
||||
336: "Cabbage",
|
||||
337: "Football helmet",
|
||||
338: "Pig",
|
||||
339: "Person",
|
||||
340: "Shorts",
|
||||
341: "Gondola",
|
||||
342: "Honeycomb",
|
||||
343: "Doughnut",
|
||||
344: "Chest of drawers",
|
||||
345: "Land vehicle",
|
||||
346: "Bat",
|
||||
347: "Monkey",
|
||||
348: "Dagger",
|
||||
349: "Tableware",
|
||||
350: "Human foot",
|
||||
351: "Mug",
|
||||
352: "Alarm clock",
|
||||
353: "Pressure cooker",
|
||||
354: "Human hand",
|
||||
355: "Tortoise",
|
||||
356: "Baseball glove",
|
||||
357: "Sword",
|
||||
358: "Pear",
|
||||
359: "Miniskirt",
|
||||
360: "Traffic sign",
|
||||
361: "Girl",
|
||||
362: "Roller skates",
|
||||
363: "Dinosaur",
|
||||
364: "Porch",
|
||||
365: "Human beard",
|
||||
366: "Submarine sandwich",
|
||||
367: "Screwdriver",
|
||||
368: "Strawberry",
|
||||
369: "Wine glass",
|
||||
370: "Seafood",
|
||||
371: "Racket",
|
||||
372: "Wheel",
|
||||
373: "Sea lion",
|
||||
374: "Toy",
|
||||
375: "Tea",
|
||||
376: "Tennis ball",
|
||||
377: "Waste container",
|
||||
378: "Mule",
|
||||
379: "Cricket ball",
|
||||
380: "Pineapple",
|
||||
381: "Coconut",
|
||||
382: "Doll",
|
||||
383: "Coffee table",
|
||||
384: "Snowman",
|
||||
385: "Lavender",
|
||||
386: "Shrimp",
|
||||
387: "Maple",
|
||||
388: "Cowboy hat",
|
||||
389: "Goggles",
|
||||
390: "Rugby ball",
|
||||
391: "Caterpillar",
|
||||
392: "Poster",
|
||||
393: "Rocket",
|
||||
394: "Organ",
|
||||
395: "Saxophone",
|
||||
396: "Traffic light",
|
||||
397: "Cocktail",
|
||||
398: "Plastic bag",
|
||||
399: "Squash",
|
||||
400: "Mushroom",
|
||||
401: "Hamburger",
|
||||
402: "Light switch",
|
||||
403: "Parachute",
|
||||
404: "Teddy bear",
|
||||
405: "Winter melon",
|
||||
406: "Deer",
|
||||
407: "Musical keyboard",
|
||||
408: "Plumbing fixture",
|
||||
409: "Scoreboard",
|
||||
410: "Baseball bat",
|
||||
411: "Envelope",
|
||||
412: "Adhesive tape",
|
||||
413: "Briefcase",
|
||||
414: "Paddle",
|
||||
415: "Bow and arrow",
|
||||
416: "Telephone",
|
||||
417: "Sheep",
|
||||
418: "Jacket",
|
||||
419: "Boy",
|
||||
420: "Pizza",
|
||||
421: "Otter",
|
||||
422: "Office supplies",
|
||||
423: "Couch",
|
||||
424: "Cello",
|
||||
425: "Bull",
|
||||
426: "Camel",
|
||||
427: "Ball",
|
||||
428: "Duck",
|
||||
429: "Whale",
|
||||
430: "Shirt",
|
||||
431: "Tank",
|
||||
432: "Motorcycle",
|
||||
433: "Accordion",
|
||||
434: "Owl",
|
||||
435: "Porcupine",
|
||||
436: "Sun hat",
|
||||
437: "Nail",
|
||||
438: "Scissors",
|
||||
439: "Swan",
|
||||
440: "Lamp",
|
||||
441: "Crown",
|
||||
442: "Piano",
|
||||
443: "Sculpture",
|
||||
444: "Cheetah",
|
||||
445: "Oboe",
|
||||
446: "Tin can",
|
||||
447: "Mango",
|
||||
448: "Tripod",
|
||||
449: "Oven",
|
||||
450: "Mouse",
|
||||
451: "Barge",
|
||||
452: "Coffee",
|
||||
453: "Snowboard",
|
||||
454: "Common fig",
|
||||
455: "Salad",
|
||||
456: "Marine invertebrates",
|
||||
457: "Umbrella",
|
||||
458: "Kangaroo",
|
||||
459: "Human arm",
|
||||
460: "Measuring cup",
|
||||
461: "Snail",
|
||||
462: "Loveseat",
|
||||
463: "Suit",
|
||||
464: "Teapot",
|
||||
465: "Bottle",
|
||||
466: "Alpaca",
|
||||
467: "Kettle",
|
||||
468: "Trousers",
|
||||
469: "Popcorn",
|
||||
470: "Centipede",
|
||||
471: "Spider",
|
||||
472: "Sparrow",
|
||||
473: "Plate",
|
||||
474: "Bagel",
|
||||
475: "Personal care",
|
||||
476: "Apple",
|
||||
477: "Brassiere",
|
||||
478: "Bathroom cabinet",
|
||||
479: "studio couch",
|
||||
480: "Computer keyboard",
|
||||
481: "Table tennis racket",
|
||||
482: "Sushi",
|
||||
483: "Cabinetry",
|
||||
484: "Street light",
|
||||
485: "Towel",
|
||||
486: "Nightstand",
|
||||
487: "Rabbit",
|
||||
488: "Dolphin",
|
||||
489: "Dog",
|
||||
490: "Jug",
|
||||
491: "Wok",
|
||||
492: "Fire hydrant",
|
||||
493: "Human eye",
|
||||
494: "Skyscraper",
|
||||
495: "Backpack",
|
||||
496: "Potato",
|
||||
497: "Paper towel",
|
||||
498: "Lifejacket",
|
||||
499: "Bicycle wheel",
|
||||
500: "Toilet",
|
||||
}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _visdrone_category():
|
||||
clsid2catid = {i: i for i in range(10)}
|
||||
|
||||
catid2name = {
|
||||
0: 'pedestrian',
|
||||
1: 'people',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'van',
|
||||
5: 'truck',
|
||||
6: 'tricycle',
|
||||
7: 'awning-tricycle',
|
||||
8: 'bus',
|
||||
9: 'motor'
|
||||
}
|
||||
return clsid2catid, catid2name
|
||||
587
rtdetr_paddle/ppdet/data/source/coco.py
Normal file
587
rtdetr_paddle/ppdet/data/source/coco.py
Normal file
@@ -0,0 +1,587 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
import numpy as np
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['COCODataSet', 'SlicedCOCODataSet', 'SemiCOCODataSet']
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class COCODataSet(DetDataset):
|
||||
"""
|
||||
Load dataset with COCO format.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): coco annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
load_crowd (bool): whether to load crowded ground-truth.
|
||||
False as default
|
||||
allow_empty (bool): whether to load empty entry. False as default
|
||||
empty_ratio (float): the ratio of empty record number to total
|
||||
record's, if empty_ratio is out of [0. ,1.), do not sample the
|
||||
records and use all the empty entries. 1. as default
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1):
|
||||
super(COCODataSet, self).__init__(
|
||||
dataset_dir,
|
||||
image_dir,
|
||||
anno_path,
|
||||
data_fields,
|
||||
sample_num,
|
||||
repeat=repeat)
|
||||
self.load_image_only = False
|
||||
self.load_semantic = False
|
||||
self.load_crowd = load_crowd
|
||||
self.allow_empty = allow_empty
|
||||
self.empty_ratio = empty_ratio
|
||||
|
||||
def _sample_empty(self, records, num):
|
||||
# if empty_ratio is out of [0. ,1.), do not sample the records
|
||||
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
|
||||
return records
|
||||
import random
|
||||
sample_num = min(
|
||||
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
|
||||
records = random.sample(records, sample_num)
|
||||
return records
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
coco_rec = {
|
||||
'im_file': im_path,
|
||||
'im_id': np.array([img_id]),
|
||||
'h': im_h,
|
||||
'w': im_w,
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
if not self.load_image_only:
|
||||
ins_anno_ids = coco.getAnnIds(
|
||||
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
|
||||
instances = coco.loadAnns(ins_anno_ids)
|
||||
|
||||
bboxes = []
|
||||
is_rbox_anno = False
|
||||
for inst in instances:
|
||||
# check gt bbox
|
||||
if inst.get('ignore', False):
|
||||
continue
|
||||
if 'bbox' not in inst.keys():
|
||||
continue
|
||||
else:
|
||||
if not any(np.array(inst['bbox'])):
|
||||
continue
|
||||
|
||||
x1, y1, box_w, box_h = inst['bbox']
|
||||
x2 = x1 + box_w
|
||||
y2 = y1 + box_h
|
||||
eps = 1e-5
|
||||
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
|
||||
inst['clean_bbox'] = [
|
||||
round(float(x), 3) for x in [x1, y1, x2, y2]
|
||||
]
|
||||
bboxes.append(inst)
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: im_id: {}, '
|
||||
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
img_id, float(inst['area']), x1, y1, x2, y2))
|
||||
|
||||
num_bbox = len(bboxes)
|
||||
if num_bbox <= 0 and not self.allow_empty:
|
||||
continue
|
||||
elif num_bbox <= 0:
|
||||
is_empty = True
|
||||
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_poly = [None] * num_bbox
|
||||
gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)
|
||||
|
||||
has_segmentation = False
|
||||
has_track_id = False
|
||||
for i, box in enumerate(bboxes):
|
||||
catid = box['category_id']
|
||||
gt_class[i][0] = self.catid2clsid[catid]
|
||||
gt_bbox[i, :] = box['clean_bbox']
|
||||
is_crowd[i][0] = box['iscrowd']
|
||||
# check RLE format
|
||||
if 'segmentation' in box and box['iscrowd'] == 1:
|
||||
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
|
||||
elif 'segmentation' in box and box['segmentation']:
|
||||
if not np.array(
|
||||
box['segmentation'],
|
||||
dtype=object).size > 0 and not self.allow_empty:
|
||||
bboxes.pop(i)
|
||||
gt_poly.pop(i)
|
||||
np.delete(is_crowd, i)
|
||||
np.delete(gt_class, i)
|
||||
np.delete(gt_bbox, i)
|
||||
else:
|
||||
gt_poly[i] = box['segmentation']
|
||||
has_segmentation = True
|
||||
|
||||
if 'track_id' in box:
|
||||
gt_track_id[i][0] = box['track_id']
|
||||
has_track_id = True
|
||||
|
||||
if has_segmentation and not any(
|
||||
gt_poly) and not self.allow_empty:
|
||||
continue
|
||||
|
||||
gt_rec = {
|
||||
'is_crowd': is_crowd,
|
||||
'gt_class': gt_class,
|
||||
'gt_bbox': gt_bbox,
|
||||
'gt_poly': gt_poly,
|
||||
}
|
||||
if has_track_id:
|
||||
gt_rec.update({'gt_track_id': gt_track_id})
|
||||
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
coco_rec[k] = v
|
||||
|
||||
# TODO: remove load_semantic
|
||||
if self.load_semantic and 'semantic' in self.data_fields:
|
||||
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
|
||||
'train2017', im_fname[:-3] + 'png')
|
||||
coco_rec.update({'semantic': seg_path})
|
||||
|
||||
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
|
||||
im_path, img_id, im_h, im_w))
|
||||
if is_empty:
|
||||
empty_records.append(coco_rec)
|
||||
else:
|
||||
records.append(coco_rec)
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
|
||||
format(ct, len(img_ids) - ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class SlicedCOCODataSet(COCODataSet):
|
||||
"""Sliced COCODataSet"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1,
|
||||
sliced_size=[640, 640],
|
||||
overlap_ratio=[0.25, 0.25], ):
|
||||
super(SlicedCOCODataSet, self).__init__(
|
||||
dataset_dir=dataset_dir,
|
||||
image_dir=image_dir,
|
||||
anno_path=anno_path,
|
||||
data_fields=data_fields,
|
||||
sample_num=sample_num,
|
||||
load_crowd=load_crowd,
|
||||
allow_empty=allow_empty,
|
||||
empty_ratio=empty_ratio,
|
||||
repeat=repeat, )
|
||||
self.sliced_size = sliced_size
|
||||
self.overlap_ratio = overlap_ratio
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
ct_sub = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
try:
|
||||
import sahi
|
||||
from sahi.slicing import slice_image
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'sahi not found, plaese install sahi. '
|
||||
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
|
||||
)
|
||||
raise e
|
||||
|
||||
sub_img_ids = 0
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
slice_image_result = sahi.slicing.slice_image(
|
||||
image=im_path,
|
||||
slice_height=self.sliced_size[0],
|
||||
slice_width=self.sliced_size[1],
|
||||
overlap_height_ratio=self.overlap_ratio[0],
|
||||
overlap_width_ratio=self.overlap_ratio[1])
|
||||
|
||||
sub_img_num = len(slice_image_result)
|
||||
for _ind in range(sub_img_num):
|
||||
im = slice_image_result.images[_ind]
|
||||
coco_rec = {
|
||||
'image': im,
|
||||
'im_id': np.array([sub_img_ids + _ind]),
|
||||
'h': im.shape[0],
|
||||
'w': im.shape[1],
|
||||
'ori_im_id': np.array([img_id]),
|
||||
'st_pix': np.array(
|
||||
slice_image_result.starting_pixels[_ind],
|
||||
dtype=np.float32),
|
||||
'is_last': 1 if _ind == sub_img_num - 1 else 0,
|
||||
} if 'image' in self.data_fields else {}
|
||||
records.append(coco_rec)
|
||||
ct_sub += sub_img_num
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('{} samples and slice to {} sub_samples in file {}'.format(
|
||||
ct, ct_sub, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class SemiCOCODataSet(COCODataSet):
|
||||
"""Semi-COCODataSet used for supervised and unsupervised dataSet"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1,
|
||||
supervised=True):
|
||||
super(SemiCOCODataSet, self).__init__(
|
||||
dataset_dir, image_dir, anno_path, data_fields, sample_num,
|
||||
load_crowd, allow_empty, empty_ratio, repeat)
|
||||
self.supervised = supervised
|
||||
self.length = -1 # defalut -1 means all
|
||||
|
||||
def parse_dataset(self):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
assert anno_path.endswith('.json'), \
|
||||
'invalid coco annotation file: ' + anno_path
|
||||
from pycocotools.coco import COCO
|
||||
coco = COCO(anno_path)
|
||||
img_ids = coco.getImgIds()
|
||||
img_ids.sort()
|
||||
cat_ids = coco.getCatIds()
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
|
||||
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
|
||||
self.cname2cid = dict({
|
||||
coco.loadCats(catid)[0]['name']: clsid
|
||||
for catid, clsid in self.catid2clsid.items()
|
||||
})
|
||||
|
||||
if 'annotations' not in coco.dataset or self.supervised == False:
|
||||
self.load_image_only = True
|
||||
logger.warning('Annotation file: {} does not contains ground truth '
|
||||
'and load image information only.'.format(anno_path))
|
||||
|
||||
for img_id in img_ids:
|
||||
img_anno = coco.loadImgs([img_id])[0]
|
||||
im_fname = img_anno['file_name']
|
||||
im_w = float(img_anno['width'])
|
||||
im_h = float(img_anno['height'])
|
||||
|
||||
im_path = os.path.join(image_dir,
|
||||
im_fname) if image_dir else im_fname
|
||||
is_empty = False
|
||||
if not os.path.exists(im_path):
|
||||
logger.warning('Illegal image file: {}, and it will be '
|
||||
'ignored'.format(im_path))
|
||||
continue
|
||||
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning('Illegal width: {} or height: {} in annotation, '
|
||||
'and im_id: {} will be ignored'.format(
|
||||
im_w, im_h, img_id))
|
||||
continue
|
||||
|
||||
coco_rec = {
|
||||
'im_file': im_path,
|
||||
'im_id': np.array([img_id]),
|
||||
'h': im_h,
|
||||
'w': im_w,
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
if not self.load_image_only:
|
||||
ins_anno_ids = coco.getAnnIds(
|
||||
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
|
||||
instances = coco.loadAnns(ins_anno_ids)
|
||||
|
||||
bboxes = []
|
||||
is_rbox_anno = False
|
||||
for inst in instances:
|
||||
# check gt bbox
|
||||
if inst.get('ignore', False):
|
||||
continue
|
||||
if 'bbox' not in inst.keys():
|
||||
continue
|
||||
else:
|
||||
if not any(np.array(inst['bbox'])):
|
||||
continue
|
||||
|
||||
x1, y1, box_w, box_h = inst['bbox']
|
||||
x2 = x1 + box_w
|
||||
y2 = y1 + box_h
|
||||
eps = 1e-5
|
||||
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
|
||||
inst['clean_bbox'] = [
|
||||
round(float(x), 3) for x in [x1, y1, x2, y2]
|
||||
]
|
||||
bboxes.append(inst)
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: im_id: {}, '
|
||||
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
img_id, float(inst['area']), x1, y1, x2, y2))
|
||||
|
||||
num_bbox = len(bboxes)
|
||||
if num_bbox <= 0 and not self.allow_empty:
|
||||
continue
|
||||
elif num_bbox <= 0:
|
||||
is_empty = True
|
||||
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_poly = [None] * num_bbox
|
||||
|
||||
has_segmentation = False
|
||||
for i, box in enumerate(bboxes):
|
||||
catid = box['category_id']
|
||||
gt_class[i][0] = self.catid2clsid[catid]
|
||||
gt_bbox[i, :] = box['clean_bbox']
|
||||
is_crowd[i][0] = box['iscrowd']
|
||||
# check RLE format
|
||||
if 'segmentation' in box and box['iscrowd'] == 1:
|
||||
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
|
||||
elif 'segmentation' in box and box['segmentation']:
|
||||
if not np.array(box['segmentation']
|
||||
).size > 0 and not self.allow_empty:
|
||||
bboxes.pop(i)
|
||||
gt_poly.pop(i)
|
||||
np.delete(is_crowd, i)
|
||||
np.delete(gt_class, i)
|
||||
np.delete(gt_bbox, i)
|
||||
else:
|
||||
gt_poly[i] = box['segmentation']
|
||||
has_segmentation = True
|
||||
|
||||
if has_segmentation and not any(
|
||||
gt_poly) and not self.allow_empty:
|
||||
continue
|
||||
|
||||
gt_rec = {
|
||||
'is_crowd': is_crowd,
|
||||
'gt_class': gt_class,
|
||||
'gt_bbox': gt_bbox,
|
||||
'gt_poly': gt_poly,
|
||||
}
|
||||
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
coco_rec[k] = v
|
||||
|
||||
# TODO: remove load_semantic
|
||||
if self.load_semantic and 'semantic' in self.data_fields:
|
||||
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
|
||||
'train2017', im_fname[:-3] + 'png')
|
||||
coco_rec.update({'semantic': seg_path})
|
||||
|
||||
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
|
||||
im_path, img_id, im_h, im_w))
|
||||
if is_empty:
|
||||
empty_records.append(coco_rec)
|
||||
else:
|
||||
records.append(coco_rec)
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any coco record in %s' % (anno_path)
|
||||
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
|
||||
format(ct, len(img_ids) - ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs = records
|
||||
|
||||
if self.supervised:
|
||||
logger.info(f'Use {len(self.roidbs)} sup_samples data as LABELED')
|
||||
else:
|
||||
if self.length > 0: # unsup length will be decide by sup length
|
||||
all_roidbs = self.roidbs.copy()
|
||||
selected_idxs = [
|
||||
np.random.choice(len(all_roidbs))
|
||||
for _ in range(self.length)
|
||||
]
|
||||
self.roidbs = [all_roidbs[i] for i in selected_idxs]
|
||||
logger.info(
|
||||
f'Use {len(self.roidbs)} unsup_samples data as UNLABELED')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
n = len(self.roidbs)
|
||||
if self.repeat > 1:
|
||||
idx %= n
|
||||
# data batch
|
||||
roidb = copy.deepcopy(self.roidbs[idx])
|
||||
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
|
||||
roidb = [roidb, ] + [
|
||||
copy.deepcopy(self.roidbs[np.random.randint(n)])
|
||||
for _ in range(4)
|
||||
]
|
||||
if isinstance(roidb, Sequence):
|
||||
for r in roidb:
|
||||
r['curr_iter'] = self._curr_iter
|
||||
else:
|
||||
roidb['curr_iter'] = self._curr_iter
|
||||
self._curr_iter += 1
|
||||
|
||||
return self.transform(roidb)
|
||||
307
rtdetr_paddle/ppdet/data/source/dataset.py
Normal file
307
rtdetr_paddle/ppdet/data/source/dataset.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
from paddle.io import Dataset
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from ppdet.utils.download import get_dataset_path
|
||||
from ppdet.data import source
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@serializable
|
||||
class DetDataset(Dataset):
|
||||
"""
|
||||
Load detection dataset.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
use_default_label (bool): whether to load default label list.
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
use_default_label=None,
|
||||
repeat=1,
|
||||
**kwargs):
|
||||
super(DetDataset, self).__init__()
|
||||
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
|
||||
self.anno_path = anno_path
|
||||
self.image_dir = image_dir if image_dir is not None else ''
|
||||
self.data_fields = data_fields
|
||||
self.sample_num = sample_num
|
||||
self.use_default_label = use_default_label
|
||||
self.repeat = repeat
|
||||
self._epoch = 0
|
||||
self._curr_iter = 0
|
||||
|
||||
def __len__(self, ):
|
||||
return len(self.roidbs) * self.repeat
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __getitem__(self, idx):
|
||||
n = len(self.roidbs)
|
||||
if self.repeat > 1:
|
||||
idx %= n
|
||||
# data batch
|
||||
roidb = copy.deepcopy(self.roidbs[idx])
|
||||
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
|
||||
idx = np.random.randint(n)
|
||||
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
|
||||
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
|
||||
roidb = [roidb, ] + [
|
||||
copy.deepcopy(self.roidbs[np.random.randint(n)])
|
||||
for _ in range(4)
|
||||
]
|
||||
elif self.pre_img_epoch == 0 or self._epoch < self.pre_img_epoch:
|
||||
# Add previous image as input, only used in CenterTrack
|
||||
idx_pre_img = idx - 1
|
||||
if idx_pre_img < 0:
|
||||
idx_pre_img = idx + 1
|
||||
roidb = [roidb, ] + [copy.deepcopy(self.roidbs[idx_pre_img])]
|
||||
if isinstance(roidb, Sequence):
|
||||
for r in roidb:
|
||||
r['curr_iter'] = self._curr_iter
|
||||
else:
|
||||
roidb['curr_iter'] = self._curr_iter
|
||||
self._curr_iter += 1
|
||||
|
||||
return self.transform(roidb)
|
||||
|
||||
def check_or_download_dataset(self):
|
||||
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
|
||||
self.image_dir)
|
||||
|
||||
def set_kwargs(self, **kwargs):
|
||||
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
|
||||
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
|
||||
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
|
||||
self.pre_img_epoch = kwargs.get('pre_img_epoch', -1)
|
||||
|
||||
def set_transform(self, transform):
|
||||
self.transform = transform
|
||||
|
||||
def set_epoch(self, epoch_id):
|
||||
self._epoch = epoch_id
|
||||
|
||||
def parse_dataset(self, ):
|
||||
raise NotImplementedError(
|
||||
"Need to implement parse_dataset method of Dataset")
|
||||
|
||||
def get_anno(self):
|
||||
if self.anno_path is None:
|
||||
return
|
||||
return os.path.join(self.dataset_dir, self.anno_path)
|
||||
|
||||
|
||||
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
|
||||
return f.lower().endswith(extensions)
|
||||
|
||||
|
||||
def _make_dataset(dir):
|
||||
dir = os.path.expanduser(dir)
|
||||
if not os.path.isdir(dir):
|
||||
raise ('{} should be a dir'.format(dir))
|
||||
images = []
|
||||
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
|
||||
for fname in sorted(fnames):
|
||||
path = os.path.join(root, fname)
|
||||
if _is_valid_file(path):
|
||||
images.append(path)
|
||||
return images
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class ImageFolder(DetDataset):
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
sample_num=-1,
|
||||
use_default_label=None,
|
||||
**kwargs):
|
||||
super(ImageFolder, self).__init__(
|
||||
dataset_dir,
|
||||
image_dir,
|
||||
anno_path,
|
||||
sample_num=sample_num,
|
||||
use_default_label=use_default_label)
|
||||
self._imid2path = {}
|
||||
self.roidbs = None
|
||||
self.sample_num = sample_num
|
||||
|
||||
def check_or_download_dataset(self):
|
||||
return
|
||||
|
||||
def get_anno(self):
|
||||
if self.anno_path is None:
|
||||
return
|
||||
if self.dataset_dir:
|
||||
return os.path.join(self.dataset_dir, self.anno_path)
|
||||
else:
|
||||
return self.anno_path
|
||||
|
||||
def parse_dataset(self, ):
|
||||
if not self.roidbs:
|
||||
self.roidbs = self._load_images()
|
||||
|
||||
def _parse(self):
|
||||
image_dir = self.image_dir
|
||||
if not isinstance(image_dir, Sequence):
|
||||
image_dir = [image_dir]
|
||||
images = []
|
||||
for im_dir in image_dir:
|
||||
if os.path.isdir(im_dir):
|
||||
im_dir = os.path.join(self.dataset_dir, im_dir)
|
||||
images.extend(_make_dataset(im_dir))
|
||||
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
|
||||
images.append(im_dir)
|
||||
return images
|
||||
|
||||
def _load_images(self):
|
||||
images = self._parse()
|
||||
ct = 0
|
||||
records = []
|
||||
for image in images:
|
||||
assert image != '' and os.path.isfile(image), \
|
||||
"Image {} not found".format(image)
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
rec = {'im_id': np.array([ct]), 'im_file': image}
|
||||
self._imid2path[ct] = image
|
||||
ct += 1
|
||||
records.append(rec)
|
||||
assert len(records) > 0, "No image file found"
|
||||
return records
|
||||
|
||||
def get_imid2path(self):
|
||||
return self._imid2path
|
||||
|
||||
def set_images(self, images):
|
||||
self.image_dir = images
|
||||
self.roidbs = self._load_images()
|
||||
|
||||
def set_slice_images(self,
|
||||
images,
|
||||
slice_size=[640, 640],
|
||||
overlap_ratio=[0.25, 0.25]):
|
||||
self.image_dir = images
|
||||
ori_records = self._load_images()
|
||||
try:
|
||||
import sahi
|
||||
from sahi.slicing import slice_image
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'sahi not found, plaese install sahi. '
|
||||
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
|
||||
)
|
||||
raise e
|
||||
|
||||
sub_img_ids = 0
|
||||
ct = 0
|
||||
ct_sub = 0
|
||||
records = []
|
||||
for i, ori_rec in enumerate(ori_records):
|
||||
im_path = ori_rec['im_file']
|
||||
slice_image_result = sahi.slicing.slice_image(
|
||||
image=im_path,
|
||||
slice_height=slice_size[0],
|
||||
slice_width=slice_size[1],
|
||||
overlap_height_ratio=overlap_ratio[0],
|
||||
overlap_width_ratio=overlap_ratio[1])
|
||||
|
||||
sub_img_num = len(slice_image_result)
|
||||
for _ind in range(sub_img_num):
|
||||
im = slice_image_result.images[_ind]
|
||||
rec = {
|
||||
'image': im,
|
||||
'im_id': np.array([sub_img_ids + _ind]),
|
||||
'h': im.shape[0],
|
||||
'w': im.shape[1],
|
||||
'ori_im_id': np.array([ori_rec['im_id'][0]]),
|
||||
'st_pix': np.array(
|
||||
slice_image_result.starting_pixels[_ind],
|
||||
dtype=np.float32),
|
||||
'is_last': 1 if _ind == sub_img_num - 1 else 0,
|
||||
} if 'image' in self.data_fields else {}
|
||||
records.append(rec)
|
||||
ct_sub += sub_img_num
|
||||
ct += 1
|
||||
logger.info('{} samples and slice to {} sub_samples.'.format(ct,
|
||||
ct_sub))
|
||||
self.roidbs = records
|
||||
|
||||
def get_label_list(self):
|
||||
# Only VOC dataset needs label list in ImageFold
|
||||
return self.anno_path
|
||||
|
||||
|
||||
@register
|
||||
class CommonDataset(object):
|
||||
def __init__(self, **dataset_args):
|
||||
super(CommonDataset, self).__init__()
|
||||
dataset_args = copy.deepcopy(dataset_args)
|
||||
type = dataset_args.pop("name")
|
||||
self.dataset = getattr(source, type)(**dataset_args)
|
||||
|
||||
def __call__(self):
|
||||
return self.dataset
|
||||
|
||||
|
||||
@register
|
||||
class TrainDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class EvalMOTDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class TestMOTDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class EvalDataset(CommonDataset):
|
||||
pass
|
||||
|
||||
|
||||
@register
|
||||
class TestDataset(CommonDataset):
|
||||
pass
|
||||
234
rtdetr_paddle/ppdet/data/source/voc.py
Normal file
234
rtdetr_paddle/ppdet/data/source/voc.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class VOCDataSet(DetDataset):
|
||||
"""
|
||||
Load dataset with PascalVOC format.
|
||||
|
||||
Notes:
|
||||
`anno_path` must contains xml file and image file path for annotations.
|
||||
|
||||
Args:
|
||||
dataset_dir (str): root directory for dataset.
|
||||
image_dir (str): directory for images.
|
||||
anno_path (str): voc annotation file path.
|
||||
data_fields (list): key name of data dictionary, at least have 'image'.
|
||||
sample_num (int): number of samples to load, -1 means all.
|
||||
label_list (str): if use_default_label is False, will load
|
||||
mapping between category and class index.
|
||||
allow_empty (bool): whether to load empty entry. False as default
|
||||
empty_ratio (float): the ratio of empty record number to total
|
||||
record's, if empty_ratio is out of [0. ,1.), do not sample the
|
||||
records and use all the empty entries. 1. as default
|
||||
repeat (int): repeat times for dataset, use in benchmark.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
label_list=None,
|
||||
allow_empty=False,
|
||||
empty_ratio=1.,
|
||||
repeat=1):
|
||||
super(VOCDataSet, self).__init__(
|
||||
dataset_dir=dataset_dir,
|
||||
image_dir=image_dir,
|
||||
anno_path=anno_path,
|
||||
data_fields=data_fields,
|
||||
sample_num=sample_num,
|
||||
repeat=repeat)
|
||||
self.label_list = label_list
|
||||
self.allow_empty = allow_empty
|
||||
self.empty_ratio = empty_ratio
|
||||
|
||||
def _sample_empty(self, records, num):
|
||||
# if empty_ratio is out of [0. ,1.), do not sample the records
|
||||
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
|
||||
return records
|
||||
import random
|
||||
sample_num = min(
|
||||
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
|
||||
records = random.sample(records, sample_num)
|
||||
return records
|
||||
|
||||
def parse_dataset(self, ):
|
||||
anno_path = os.path.join(self.dataset_dir, self.anno_path)
|
||||
image_dir = os.path.join(self.dataset_dir, self.image_dir)
|
||||
|
||||
# mapping category name to class id
|
||||
# first_class:0, second_class:1, ...
|
||||
records = []
|
||||
empty_records = []
|
||||
ct = 0
|
||||
cname2cid = {}
|
||||
if self.label_list:
|
||||
label_path = os.path.join(self.dataset_dir, self.label_list)
|
||||
if not os.path.exists(label_path):
|
||||
raise ValueError("label_list {} does not exists".format(
|
||||
label_path))
|
||||
with open(label_path, 'r') as fr:
|
||||
label_id = 0
|
||||
for line in fr.readlines():
|
||||
cname2cid[line.strip()] = label_id
|
||||
label_id += 1
|
||||
else:
|
||||
cname2cid = pascalvoc_label()
|
||||
|
||||
with open(anno_path, 'r') as fr:
|
||||
while True:
|
||||
line = fr.readline()
|
||||
if not line:
|
||||
break
|
||||
img_file, xml_file = [os.path.join(image_dir, x) \
|
||||
for x in line.strip().split()[:2]]
|
||||
if not os.path.exists(img_file):
|
||||
logger.warning(
|
||||
'Illegal image file: {}, and it will be ignored'.format(
|
||||
img_file))
|
||||
continue
|
||||
if not os.path.isfile(xml_file):
|
||||
logger.warning(
|
||||
'Illegal xml file: {}, and it will be ignored'.format(
|
||||
xml_file))
|
||||
continue
|
||||
tree = ET.parse(xml_file)
|
||||
if tree.find('id') is None:
|
||||
im_id = np.array([ct])
|
||||
else:
|
||||
im_id = np.array([int(tree.find('id').text)])
|
||||
|
||||
objs = tree.findall('object')
|
||||
im_w = float(tree.find('size').find('width').text)
|
||||
im_h = float(tree.find('size').find('height').text)
|
||||
if im_w < 0 or im_h < 0:
|
||||
logger.warning(
|
||||
'Illegal width: {} or height: {} in annotation, '
|
||||
'and {} will be ignored'.format(im_w, im_h, xml_file))
|
||||
continue
|
||||
|
||||
num_bbox, i = len(objs), 0
|
||||
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
|
||||
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
|
||||
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
|
||||
for obj in objs:
|
||||
cname = obj.find('name').text
|
||||
|
||||
# user dataset may not contain difficult field
|
||||
_difficult = obj.find('difficult')
|
||||
_difficult = int(
|
||||
_difficult.text) if _difficult is not None else 0
|
||||
|
||||
x1 = float(obj.find('bndbox').find('xmin').text)
|
||||
y1 = float(obj.find('bndbox').find('ymin').text)
|
||||
x2 = float(obj.find('bndbox').find('xmax').text)
|
||||
y2 = float(obj.find('bndbox').find('ymax').text)
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(im_w - 1, x2)
|
||||
y2 = min(im_h - 1, y2)
|
||||
if x2 > x1 and y2 > y1:
|
||||
gt_bbox[i, :] = [x1, y1, x2, y2]
|
||||
gt_class[i, 0] = cname2cid[cname]
|
||||
gt_score[i, 0] = 1.
|
||||
difficult[i, 0] = _difficult
|
||||
i += 1
|
||||
else:
|
||||
logger.warning(
|
||||
'Found an invalid bbox in annotations: xml_file: {}'
|
||||
', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
|
||||
xml_file, x1, y1, x2, y2))
|
||||
gt_bbox = gt_bbox[:i, :]
|
||||
gt_class = gt_class[:i, :]
|
||||
gt_score = gt_score[:i, :]
|
||||
difficult = difficult[:i, :]
|
||||
|
||||
voc_rec = {
|
||||
'im_file': img_file,
|
||||
'im_id': im_id,
|
||||
'h': im_h,
|
||||
'w': im_w
|
||||
} if 'image' in self.data_fields else {}
|
||||
|
||||
gt_rec = {
|
||||
'gt_class': gt_class,
|
||||
'gt_score': gt_score,
|
||||
'gt_bbox': gt_bbox,
|
||||
'difficult': difficult
|
||||
}
|
||||
for k, v in gt_rec.items():
|
||||
if k in self.data_fields:
|
||||
voc_rec[k] = v
|
||||
|
||||
if len(objs) == 0:
|
||||
empty_records.append(voc_rec)
|
||||
else:
|
||||
records.append(voc_rec)
|
||||
|
||||
ct += 1
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
|
||||
logger.debug('{} samples in file {}'.format(ct, anno_path))
|
||||
if self.allow_empty and len(empty_records) > 0:
|
||||
empty_records = self._sample_empty(empty_records, len(records))
|
||||
records += empty_records
|
||||
self.roidbs, self.cname2cid = records, cname2cid
|
||||
|
||||
def get_label_list(self):
|
||||
return os.path.join(self.dataset_dir, self.label_list)
|
||||
|
||||
|
||||
def pascalvoc_label():
|
||||
labels_map = {
|
||||
'aeroplane': 0,
|
||||
'bicycle': 1,
|
||||
'bird': 2,
|
||||
'boat': 3,
|
||||
'bottle': 4,
|
||||
'bus': 5,
|
||||
'car': 6,
|
||||
'cat': 7,
|
||||
'chair': 8,
|
||||
'cow': 9,
|
||||
'diningtable': 10,
|
||||
'dog': 11,
|
||||
'horse': 12,
|
||||
'motorbike': 13,
|
||||
'person': 14,
|
||||
'pottedplant': 15,
|
||||
'sheep': 16,
|
||||
'sofa': 17,
|
||||
'train': 18,
|
||||
'tvmonitor': 19
|
||||
}
|
||||
return labels_map
|
||||
25
rtdetr_paddle/ppdet/data/transform/__init__.py
Normal file
25
rtdetr_paddle/ppdet/data/transform/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import operators
|
||||
from . import batch_operators
|
||||
|
||||
|
||||
from .operators import *
|
||||
from .batch_operators import *
|
||||
|
||||
|
||||
__all__ = []
|
||||
__all__ += registered_ops
|
||||
|
||||
322
rtdetr_paddle/ppdet/data/transform/batch_operators.py
Normal file
322
rtdetr_paddle/ppdet/data/transform/batch_operators.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import typing
|
||||
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from .operators import register_op, BaseOperator, Resize
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'PadBatch',
|
||||
'BatchRandomResize',
|
||||
'PadGT',
|
||||
]
|
||||
|
||||
|
||||
@register_op
|
||||
class PadBatch(BaseOperator):
|
||||
"""
|
||||
Pad a batch of samples so they can be divisible by a stride.
|
||||
The layout of each image should be 'CHW'.
|
||||
Args:
|
||||
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
|
||||
height and width is divisible by `pad_to_stride`.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_to_stride=0):
|
||||
super(PadBatch, self).__init__()
|
||||
self.pad_to_stride = pad_to_stride
|
||||
|
||||
def __call__(self, samples, context=None):
|
||||
"""
|
||||
Args:
|
||||
samples (list): a batch of sample, each is dict.
|
||||
"""
|
||||
coarsest_stride = self.pad_to_stride
|
||||
|
||||
# multi scale input is nested list
|
||||
if isinstance(samples,
|
||||
typing.Sequence) and len(samples) > 0 and isinstance(
|
||||
samples[0], typing.Sequence):
|
||||
inner_samples = samples[0]
|
||||
else:
|
||||
inner_samples = samples
|
||||
|
||||
max_shape = np.array(
|
||||
[data['image'].shape for data in inner_samples]).max(axis=0)
|
||||
if coarsest_stride > 0:
|
||||
max_shape[1] = int(
|
||||
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
|
||||
max_shape[2] = int(
|
||||
np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
|
||||
|
||||
for data in inner_samples:
|
||||
im = data['image']
|
||||
im_c, im_h, im_w = im.shape[:]
|
||||
padding_im = np.zeros(
|
||||
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
data['image'] = padding_im
|
||||
if 'semantic' in data and data['semantic'] is not None:
|
||||
semantic = data['semantic']
|
||||
padding_sem = np.zeros(
|
||||
(1, max_shape[1], max_shape[2]), dtype=np.float32)
|
||||
padding_sem[:, :im_h, :im_w] = semantic
|
||||
data['semantic'] = padding_sem
|
||||
if 'gt_segm' in data and data['gt_segm'] is not None:
|
||||
gt_segm = data['gt_segm']
|
||||
padding_segm = np.zeros(
|
||||
(gt_segm.shape[0], max_shape[1], max_shape[2]),
|
||||
dtype=np.uint8)
|
||||
padding_segm[:, :im_h, :im_w] = gt_segm
|
||||
data['gt_segm'] = padding_segm
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
@register_op
|
||||
class BatchRandomResize(BaseOperator):
|
||||
"""
|
||||
Resize image to target size randomly. random target_size and interpolation method
|
||||
Args:
|
||||
target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
|
||||
keep_ratio (bool): whether keep_raio or not, default true
|
||||
interp (int): the interpolation method
|
||||
random_size (bool): whether random select target size of image
|
||||
random_interp (bool): whether random select interpolation method
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
target_size,
|
||||
keep_ratio,
|
||||
interp=cv2.INTER_NEAREST,
|
||||
random_size=True,
|
||||
random_interp=False):
|
||||
super(BatchRandomResize, self).__init__()
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interps = [
|
||||
cv2.INTER_NEAREST,
|
||||
cv2.INTER_LINEAR,
|
||||
cv2.INTER_AREA,
|
||||
cv2.INTER_CUBIC,
|
||||
cv2.INTER_LANCZOS4,
|
||||
]
|
||||
self.interp = interp
|
||||
assert isinstance(target_size, (
|
||||
int, Sequence)), "target_size must be int, list or tuple"
|
||||
if random_size and not isinstance(target_size, list):
|
||||
raise TypeError(
|
||||
"Type of target_size is invalid when random_size is True. Must be List, now is {}".
|
||||
format(type(target_size)))
|
||||
self.target_size = target_size
|
||||
self.random_size = random_size
|
||||
self.random_interp = random_interp
|
||||
|
||||
def __call__(self, samples, context=None):
|
||||
if self.random_size:
|
||||
index = np.random.choice(len(self.target_size))
|
||||
target_size = self.target_size[index]
|
||||
else:
|
||||
target_size = self.target_size
|
||||
|
||||
if self.random_interp:
|
||||
interp = np.random.choice(self.interps)
|
||||
else:
|
||||
interp = self.interp
|
||||
|
||||
resizer = Resize(target_size, keep_ratio=self.keep_ratio, interp=interp)
|
||||
return resizer(samples, context=context)
|
||||
|
||||
|
||||
@register_op
|
||||
class PadGT(BaseOperator):
|
||||
"""
|
||||
Pad 0 to `gt_class`, `gt_bbox`, `gt_score`...
|
||||
The num_max_boxes is the largest for batch.
|
||||
Args:
|
||||
return_gt_mask (bool): If true, return `pad_gt_mask`,
|
||||
1 means bbox, 0 means no bbox.
|
||||
"""
|
||||
|
||||
def __init__(self, return_gt_mask=True, pad_img=False, minimum_gtnum=0):
|
||||
super(PadGT, self).__init__()
|
||||
self.return_gt_mask = return_gt_mask
|
||||
self.pad_img = pad_img
|
||||
self.minimum_gtnum = minimum_gtnum
|
||||
|
||||
def _impad(self,
|
||||
img: np.ndarray,
|
||||
*,
|
||||
shape=None,
|
||||
padding=None,
|
||||
pad_val=0,
|
||||
padding_mode='constant') -> np.ndarray:
|
||||
"""Pad the given image to a certain shape or pad on all sides with
|
||||
specified padding mode and padding value.
|
||||
|
||||
Args:
|
||||
img (ndarray): Image to be padded.
|
||||
shape (tuple[int]): Expected padding shape (h, w). Default: None.
|
||||
padding (int or tuple[int]): Padding on each border. If a single int is
|
||||
provided this is used to pad all borders. If tuple of length 2 is
|
||||
provided this is the padding on left/right and top/bottom
|
||||
respectively. If a tuple of length 4 is provided this is the
|
||||
padding for the left, top, right and bottom borders respectively.
|
||||
Default: None. Note that `shape` and `padding` can not be both
|
||||
set.
|
||||
pad_val (Number | Sequence[Number]): Values to be filled in padding
|
||||
areas when padding_mode is 'constant'. Default: 0.
|
||||
padding_mode (str): Type of padding. Should be: constant, edge,
|
||||
reflect or symmetric. Default: constant.
|
||||
- constant: pads with a constant value, this value is specified
|
||||
with pad_val.
|
||||
- edge: pads with the last value at the edge of the image.
|
||||
- reflect: pads with reflection of image without repeating the last
|
||||
value on the edge. For example, padding [1, 2, 3, 4] with 2
|
||||
elements on both sides in reflect mode will result in
|
||||
[3, 2, 1, 2, 3, 4, 3, 2].
|
||||
- symmetric: pads with reflection of image repeating the last value
|
||||
on the edge. For example, padding [1, 2, 3, 4] with 2 elements on
|
||||
both sides in symmetric mode will result in
|
||||
[2, 1, 1, 2, 3, 4, 4, 3]
|
||||
|
||||
Returns:
|
||||
ndarray: The padded image.
|
||||
"""
|
||||
|
||||
assert (shape is not None) ^ (padding is not None)
|
||||
if shape is not None:
|
||||
width = max(shape[1] - img.shape[1], 0)
|
||||
height = max(shape[0] - img.shape[0], 0)
|
||||
padding = (0, 0, int(width), int(height))
|
||||
|
||||
# check pad_val
|
||||
import numbers
|
||||
if isinstance(pad_val, tuple):
|
||||
assert len(pad_val) == img.shape[-1]
|
||||
elif not isinstance(pad_val, numbers.Number):
|
||||
raise TypeError('pad_val must be a int or a tuple. '
|
||||
f'But received {type(pad_val)}')
|
||||
|
||||
# check padding
|
||||
if isinstance(padding, tuple) and len(padding) in [2, 4]:
|
||||
if len(padding) == 2:
|
||||
padding = (padding[0], padding[1], padding[0], padding[1])
|
||||
elif isinstance(padding, numbers.Number):
|
||||
padding = (padding, padding, padding, padding)
|
||||
else:
|
||||
raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
|
||||
f'But received {padding}')
|
||||
|
||||
# check padding mode
|
||||
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
||||
|
||||
border_type = {
|
||||
'constant': cv2.BORDER_CONSTANT,
|
||||
'edge': cv2.BORDER_REPLICATE,
|
||||
'reflect': cv2.BORDER_REFLECT_101,
|
||||
'symmetric': cv2.BORDER_REFLECT
|
||||
}
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
padding[1],
|
||||
padding[3],
|
||||
padding[0],
|
||||
padding[2],
|
||||
border_type[padding_mode],
|
||||
value=pad_val)
|
||||
|
||||
return img
|
||||
|
||||
def checkmaxshape(self, samples):
|
||||
maxh, maxw = 0, 0
|
||||
for sample in samples:
|
||||
h, w = sample['im_shape']
|
||||
if h > maxh:
|
||||
maxh = h
|
||||
if w > maxw:
|
||||
maxw = w
|
||||
return (maxh, maxw)
|
||||
|
||||
def __call__(self, samples, context=None):
|
||||
num_max_boxes = max([len(s['gt_bbox']) for s in samples])
|
||||
num_max_boxes = max(self.minimum_gtnum, num_max_boxes)
|
||||
if self.pad_img:
|
||||
maxshape = self.checkmaxshape(samples)
|
||||
for sample in samples:
|
||||
if self.pad_img:
|
||||
img = sample['image']
|
||||
padimg = self._impad(img, shape=maxshape)
|
||||
sample['image'] = padimg
|
||||
if self.return_gt_mask:
|
||||
sample['pad_gt_mask'] = np.zeros(
|
||||
(num_max_boxes, 1), dtype=np.float32)
|
||||
if num_max_boxes == 0:
|
||||
continue
|
||||
|
||||
num_gt = len(sample['gt_bbox'])
|
||||
pad_gt_class = np.zeros((num_max_boxes, 1), dtype=np.int32)
|
||||
pad_gt_bbox = np.zeros((num_max_boxes, 4), dtype=np.float32)
|
||||
if num_gt > 0:
|
||||
pad_gt_class[:num_gt] = sample['gt_class']
|
||||
pad_gt_bbox[:num_gt] = sample['gt_bbox']
|
||||
sample['gt_class'] = pad_gt_class
|
||||
sample['gt_bbox'] = pad_gt_bbox
|
||||
# pad_gt_mask
|
||||
if 'pad_gt_mask' in sample:
|
||||
sample['pad_gt_mask'][:num_gt] = 1
|
||||
# gt_score
|
||||
if 'gt_score' in sample:
|
||||
pad_gt_score = np.zeros((num_max_boxes, 1), dtype=np.float32)
|
||||
if num_gt > 0:
|
||||
pad_gt_score[:num_gt] = sample['gt_score']
|
||||
sample['gt_score'] = pad_gt_score
|
||||
if 'is_crowd' in sample:
|
||||
pad_is_crowd = np.zeros((num_max_boxes, 1), dtype=np.int32)
|
||||
if num_gt > 0:
|
||||
pad_is_crowd[:num_gt] = sample['is_crowd']
|
||||
sample['is_crowd'] = pad_is_crowd
|
||||
if 'difficult' in sample:
|
||||
pad_diff = np.zeros((num_max_boxes, 1), dtype=np.int32)
|
||||
if num_gt > 0:
|
||||
pad_diff[:num_gt] = sample['difficult']
|
||||
sample['difficult'] = pad_diff
|
||||
if 'gt_joints' in sample:
|
||||
num_joints = sample['gt_joints'].shape[1]
|
||||
pad_gt_joints = np.zeros(
|
||||
(num_max_boxes, num_joints, 3), dtype=np.float32)
|
||||
if num_gt > 0:
|
||||
pad_gt_joints[:num_gt] = sample['gt_joints']
|
||||
sample['gt_joints'] = pad_gt_joints
|
||||
if 'gt_areas' in sample:
|
||||
pad_gt_areas = np.zeros((num_max_boxes, 1), dtype=np.float32)
|
||||
if num_gt > 0:
|
||||
pad_gt_areas[:num_gt, 0] = sample['gt_areas']
|
||||
sample['gt_areas'] = pad_gt_areas
|
||||
return samples
|
||||
|
||||
|
||||
|
||||
494
rtdetr_paddle/ppdet/data/transform/op_helper.py
Normal file
494
rtdetr_paddle/ppdet/data/transform/op_helper.py
Normal file
@@ -0,0 +1,494 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# this file contains helper methods for BBOX processing
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
import math
|
||||
import cv2
|
||||
|
||||
|
||||
def meet_emit_constraint(src_bbox, sample_bbox):
|
||||
center_x = (src_bbox[2] + src_bbox[0]) / 2
|
||||
center_y = (src_bbox[3] + src_bbox[1]) / 2
|
||||
if center_x >= sample_bbox[0] and \
|
||||
center_x <= sample_bbox[2] and \
|
||||
center_y >= sample_bbox[1] and \
|
||||
center_y <= sample_bbox[3]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def clip_bbox(src_bbox):
|
||||
src_bbox[0] = max(min(src_bbox[0], 1.0), 0.0)
|
||||
src_bbox[1] = max(min(src_bbox[1], 1.0), 0.0)
|
||||
src_bbox[2] = max(min(src_bbox[2], 1.0), 0.0)
|
||||
src_bbox[3] = max(min(src_bbox[3], 1.0), 0.0)
|
||||
return src_bbox
|
||||
|
||||
|
||||
def bbox_area(src_bbox):
|
||||
if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
|
||||
return 0.
|
||||
else:
|
||||
width = src_bbox[2] - src_bbox[0]
|
||||
height = src_bbox[3] - src_bbox[1]
|
||||
return width * height
|
||||
|
||||
|
||||
def is_overlap(object_bbox, sample_bbox):
|
||||
if object_bbox[0] >= sample_bbox[2] or \
|
||||
object_bbox[2] <= sample_bbox[0] or \
|
||||
object_bbox[1] >= sample_bbox[3] or \
|
||||
object_bbox[3] <= sample_bbox[1]:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def filter_and_process(sample_bbox, bboxes, labels, scores=None,
|
||||
keypoints=None):
|
||||
new_bboxes = []
|
||||
new_labels = []
|
||||
new_scores = []
|
||||
new_keypoints = []
|
||||
new_kp_ignore = []
|
||||
for i in range(len(bboxes)):
|
||||
new_bbox = [0, 0, 0, 0]
|
||||
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
|
||||
if not meet_emit_constraint(obj_bbox, sample_bbox):
|
||||
continue
|
||||
if not is_overlap(obj_bbox, sample_bbox):
|
||||
continue
|
||||
sample_width = sample_bbox[2] - sample_bbox[0]
|
||||
sample_height = sample_bbox[3] - sample_bbox[1]
|
||||
new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
|
||||
new_bbox[1] = (obj_bbox[1] - sample_bbox[1]) / sample_height
|
||||
new_bbox[2] = (obj_bbox[2] - sample_bbox[0]) / sample_width
|
||||
new_bbox[3] = (obj_bbox[3] - sample_bbox[1]) / sample_height
|
||||
new_bbox = clip_bbox(new_bbox)
|
||||
if bbox_area(new_bbox) > 0:
|
||||
new_bboxes.append(new_bbox)
|
||||
new_labels.append([labels[i][0]])
|
||||
if scores is not None:
|
||||
new_scores.append([scores[i][0]])
|
||||
if keypoints is not None:
|
||||
sample_keypoint = keypoints[0][i]
|
||||
for j in range(len(sample_keypoint)):
|
||||
kp_len = sample_height if j % 2 else sample_width
|
||||
sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
|
||||
sample_keypoint[j] = (
|
||||
sample_keypoint[j] - sample_coord) / kp_len
|
||||
sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
|
||||
new_keypoints.append(sample_keypoint)
|
||||
new_kp_ignore.append(keypoints[1][i])
|
||||
|
||||
bboxes = np.array(new_bboxes)
|
||||
labels = np.array(new_labels)
|
||||
scores = np.array(new_scores)
|
||||
if keypoints is not None:
|
||||
keypoints = np.array(new_keypoints)
|
||||
new_kp_ignore = np.array(new_kp_ignore)
|
||||
return bboxes, labels, scores, (keypoints, new_kp_ignore)
|
||||
return bboxes, labels, scores
|
||||
|
||||
|
||||
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
|
||||
new_bboxes = []
|
||||
new_labels = []
|
||||
new_scores = []
|
||||
for i, bbox in enumerate(bboxes):
|
||||
w = float((bbox[2] - bbox[0]) * target_size)
|
||||
h = float((bbox[3] - bbox[1]) * target_size)
|
||||
if w * h < float(min_size * min_size):
|
||||
continue
|
||||
else:
|
||||
new_bboxes.append(bbox)
|
||||
new_labels.append(labels[i])
|
||||
if scores is not None and scores.size != 0:
|
||||
new_scores.append(scores[i])
|
||||
bboxes = np.array(new_bboxes)
|
||||
labels = np.array(new_labels)
|
||||
scores = np.array(new_scores)
|
||||
return bboxes, labels, scores
|
||||
|
||||
|
||||
def generate_sample_bbox(sampler):
|
||||
scale = np.random.uniform(sampler[2], sampler[3])
|
||||
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
|
||||
aspect_ratio = max(aspect_ratio, (scale**2.0))
|
||||
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
|
||||
bbox_width = scale * (aspect_ratio**0.5)
|
||||
bbox_height = scale / (aspect_ratio**0.5)
|
||||
xmin_bound = 1 - bbox_width
|
||||
ymin_bound = 1 - bbox_height
|
||||
xmin = np.random.uniform(0, xmin_bound)
|
||||
ymin = np.random.uniform(0, ymin_bound)
|
||||
xmax = xmin + bbox_width
|
||||
ymax = ymin + bbox_height
|
||||
sampled_bbox = [xmin, ymin, xmax, ymax]
|
||||
return sampled_bbox
|
||||
|
||||
|
||||
def generate_sample_bbox_square(sampler, image_width, image_height):
|
||||
scale = np.random.uniform(sampler[2], sampler[3])
|
||||
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
|
||||
aspect_ratio = max(aspect_ratio, (scale**2.0))
|
||||
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
|
||||
bbox_width = scale * (aspect_ratio**0.5)
|
||||
bbox_height = scale / (aspect_ratio**0.5)
|
||||
if image_height < image_width:
|
||||
bbox_width = bbox_height * image_height / image_width
|
||||
else:
|
||||
bbox_height = bbox_width * image_width / image_height
|
||||
xmin_bound = 1 - bbox_width
|
||||
ymin_bound = 1 - bbox_height
|
||||
xmin = np.random.uniform(0, xmin_bound)
|
||||
ymin = np.random.uniform(0, ymin_bound)
|
||||
xmax = xmin + bbox_width
|
||||
ymax = ymin + bbox_height
|
||||
sampled_bbox = [xmin, ymin, xmax, ymax]
|
||||
return sampled_bbox
|
||||
|
||||
|
||||
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
|
||||
resize_width):
|
||||
num_gt = len(bbox_labels)
|
||||
# np.random.randint range: [low, high)
|
||||
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
|
||||
|
||||
if num_gt != 0:
|
||||
norm_xmin = bbox_labels[rand_idx][0]
|
||||
norm_ymin = bbox_labels[rand_idx][1]
|
||||
norm_xmax = bbox_labels[rand_idx][2]
|
||||
norm_ymax = bbox_labels[rand_idx][3]
|
||||
|
||||
xmin = norm_xmin * image_width
|
||||
ymin = norm_ymin * image_height
|
||||
wid = image_width * (norm_xmax - norm_xmin)
|
||||
hei = image_height * (norm_ymax - norm_ymin)
|
||||
range_size = 0
|
||||
|
||||
area = wid * hei
|
||||
for scale_ind in range(0, len(scale_array) - 1):
|
||||
if area > scale_array[scale_ind] ** 2 and area < \
|
||||
scale_array[scale_ind + 1] ** 2:
|
||||
range_size = scale_ind + 1
|
||||
break
|
||||
|
||||
if area > scale_array[len(scale_array) - 2]**2:
|
||||
range_size = len(scale_array) - 2
|
||||
|
||||
scale_choose = 0.0
|
||||
if range_size == 0:
|
||||
rand_idx_size = 0
|
||||
else:
|
||||
# np.random.randint range: [low, high)
|
||||
rng_rand_size = np.random.randint(0, range_size + 1)
|
||||
rand_idx_size = rng_rand_size % (range_size + 1)
|
||||
|
||||
if rand_idx_size == range_size:
|
||||
min_resize_val = scale_array[rand_idx_size] / 2.0
|
||||
max_resize_val = min(2.0 * scale_array[rand_idx_size],
|
||||
2 * math.sqrt(wid * hei))
|
||||
scale_choose = random.uniform(min_resize_val, max_resize_val)
|
||||
else:
|
||||
min_resize_val = scale_array[rand_idx_size] / 2.0
|
||||
max_resize_val = 2.0 * scale_array[rand_idx_size]
|
||||
scale_choose = random.uniform(min_resize_val, max_resize_val)
|
||||
|
||||
sample_bbox_size = wid * resize_width / scale_choose
|
||||
|
||||
w_off_orig = 0.0
|
||||
h_off_orig = 0.0
|
||||
if sample_bbox_size < max(image_height, image_width):
|
||||
if wid <= sample_bbox_size:
|
||||
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
|
||||
xmin)
|
||||
else:
|
||||
w_off_orig = np.random.uniform(xmin,
|
||||
xmin + wid - sample_bbox_size)
|
||||
|
||||
if hei <= sample_bbox_size:
|
||||
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
|
||||
ymin)
|
||||
else:
|
||||
h_off_orig = np.random.uniform(ymin,
|
||||
ymin + hei - sample_bbox_size)
|
||||
|
||||
else:
|
||||
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
|
||||
h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0)
|
||||
|
||||
w_off_orig = math.floor(w_off_orig)
|
||||
h_off_orig = math.floor(h_off_orig)
|
||||
|
||||
# Figure out top left coordinates.
|
||||
w_off = float(w_off_orig / image_width)
|
||||
h_off = float(h_off_orig / image_height)
|
||||
|
||||
sampled_bbox = [
|
||||
w_off, h_off, w_off + float(sample_bbox_size / image_width),
|
||||
h_off + float(sample_bbox_size / image_height)
|
||||
]
|
||||
return sampled_bbox
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def jaccard_overlap(sample_bbox, object_bbox):
|
||||
if sample_bbox[0] >= object_bbox[2] or \
|
||||
sample_bbox[2] <= object_bbox[0] or \
|
||||
sample_bbox[1] >= object_bbox[3] or \
|
||||
sample_bbox[3] <= object_bbox[1]:
|
||||
return 0
|
||||
intersect_xmin = max(sample_bbox[0], object_bbox[0])
|
||||
intersect_ymin = max(sample_bbox[1], object_bbox[1])
|
||||
intersect_xmax = min(sample_bbox[2], object_bbox[2])
|
||||
intersect_ymax = min(sample_bbox[3], object_bbox[3])
|
||||
intersect_size = (intersect_xmax - intersect_xmin) * (
|
||||
intersect_ymax - intersect_ymin)
|
||||
sample_bbox_size = bbox_area(sample_bbox)
|
||||
object_bbox_size = bbox_area(object_bbox)
|
||||
overlap = intersect_size / (
|
||||
sample_bbox_size + object_bbox_size - intersect_size)
|
||||
return overlap
|
||||
|
||||
|
||||
def intersect_bbox(bbox1, bbox2):
|
||||
if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
|
||||
bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
|
||||
intersection_box = [0.0, 0.0, 0.0, 0.0]
|
||||
else:
|
||||
intersection_box = [
|
||||
max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
|
||||
min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
|
||||
]
|
||||
return intersection_box
|
||||
|
||||
|
||||
def bbox_coverage(bbox1, bbox2):
|
||||
inter_box = intersect_bbox(bbox1, bbox2)
|
||||
intersect_size = bbox_area(inter_box)
|
||||
|
||||
if intersect_size > 0:
|
||||
bbox1_size = bbox_area(bbox1)
|
||||
return intersect_size / bbox1_size
|
||||
else:
|
||||
return 0.
|
||||
|
||||
|
||||
def satisfy_sample_constraint(sampler,
|
||||
sample_bbox,
|
||||
gt_bboxes,
|
||||
satisfy_all=False):
|
||||
if sampler[6] == 0 and sampler[7] == 0:
|
||||
return True
|
||||
satisfied = []
|
||||
for i in range(len(gt_bboxes)):
|
||||
object_bbox = [
|
||||
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
|
||||
]
|
||||
overlap = jaccard_overlap(sample_bbox, object_bbox)
|
||||
if sampler[6] != 0 and \
|
||||
overlap < sampler[6]:
|
||||
satisfied.append(False)
|
||||
continue
|
||||
if sampler[7] != 0 and \
|
||||
overlap > sampler[7]:
|
||||
satisfied.append(False)
|
||||
continue
|
||||
satisfied.append(True)
|
||||
if not satisfy_all:
|
||||
return True
|
||||
|
||||
if satisfy_all:
|
||||
return np.all(satisfied)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
|
||||
if sampler[6] == 0 and sampler[7] == 0:
|
||||
has_jaccard_overlap = False
|
||||
else:
|
||||
has_jaccard_overlap = True
|
||||
if sampler[8] == 0 and sampler[9] == 0:
|
||||
has_object_coverage = False
|
||||
else:
|
||||
has_object_coverage = True
|
||||
|
||||
if not has_jaccard_overlap and not has_object_coverage:
|
||||
return True
|
||||
found = False
|
||||
for i in range(len(gt_bboxes)):
|
||||
object_bbox = [
|
||||
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
|
||||
]
|
||||
if has_jaccard_overlap:
|
||||
overlap = jaccard_overlap(sample_bbox, object_bbox)
|
||||
if sampler[6] != 0 and \
|
||||
overlap < sampler[6]:
|
||||
continue
|
||||
if sampler[7] != 0 and \
|
||||
overlap > sampler[7]:
|
||||
continue
|
||||
found = True
|
||||
if has_object_coverage:
|
||||
object_coverage = bbox_coverage(object_bbox, sample_bbox)
|
||||
if sampler[8] != 0 and \
|
||||
object_coverage < sampler[8]:
|
||||
continue
|
||||
if sampler[9] != 0 and \
|
||||
object_coverage > sampler[9]:
|
||||
continue
|
||||
found = True
|
||||
if found:
|
||||
return True
|
||||
return found
|
||||
|
||||
|
||||
def crop_image_sampling(img, sample_bbox, image_width, image_height,
|
||||
target_size):
|
||||
# no clipping here
|
||||
xmin = int(sample_bbox[0] * image_width)
|
||||
xmax = int(sample_bbox[2] * image_width)
|
||||
ymin = int(sample_bbox[1] * image_height)
|
||||
ymax = int(sample_bbox[3] * image_height)
|
||||
|
||||
w_off = xmin
|
||||
h_off = ymin
|
||||
width = xmax - xmin
|
||||
height = ymax - ymin
|
||||
cross_xmin = max(0.0, float(w_off))
|
||||
cross_ymin = max(0.0, float(h_off))
|
||||
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
|
||||
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
|
||||
cross_width = cross_xmax - cross_xmin
|
||||
cross_height = cross_ymax - cross_ymin
|
||||
|
||||
roi_xmin = 0 if w_off >= 0 else abs(w_off)
|
||||
roi_ymin = 0 if h_off >= 0 else abs(h_off)
|
||||
roi_width = cross_width
|
||||
roi_height = cross_height
|
||||
|
||||
roi_y1 = int(roi_ymin)
|
||||
roi_y2 = int(roi_ymin + roi_height)
|
||||
roi_x1 = int(roi_xmin)
|
||||
roi_x2 = int(roi_xmin + roi_width)
|
||||
|
||||
cross_y1 = int(cross_ymin)
|
||||
cross_y2 = int(cross_ymin + cross_height)
|
||||
cross_x1 = int(cross_xmin)
|
||||
cross_x2 = int(cross_xmin + cross_width)
|
||||
|
||||
sample_img = np.zeros((height, width, 3))
|
||||
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
|
||||
img[cross_y1: cross_y2, cross_x1: cross_x2]
|
||||
|
||||
sample_img = cv2.resize(
|
||||
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
|
||||
|
||||
return sample_img
|
||||
|
||||
|
||||
def is_poly(segm):
|
||||
assert isinstance(segm, (list, dict)), \
|
||||
"Invalid segm type: {}".format(type(segm))
|
||||
return isinstance(segm, list)
|
||||
|
||||
|
||||
def gaussian_radius(bbox_size, min_overlap):
|
||||
height, width = bbox_size
|
||||
|
||||
a1 = 1
|
||||
b1 = (height + width)
|
||||
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
|
||||
sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
|
||||
radius1 = (b1 + sq1) / (2 * a1)
|
||||
|
||||
a2 = 4
|
||||
b2 = 2 * (height + width)
|
||||
c2 = (1 - min_overlap) * width * height
|
||||
sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
|
||||
radius2 = (b2 + sq2) / 2
|
||||
|
||||
a3 = 4 * min_overlap
|
||||
b3 = -2 * min_overlap * (height + width)
|
||||
c3 = (min_overlap - 1) * width * height
|
||||
sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
|
||||
radius3 = (b3 + sq3) / 2
|
||||
return min(radius1, radius2, radius3)
|
||||
|
||||
|
||||
def draw_gaussian(heatmap, center, radius, k=1, delte=6):
|
||||
diameter = 2 * radius + 1
|
||||
sigma = diameter / delte
|
||||
gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
|
||||
|
||||
x, y = center
|
||||
|
||||
height, width = heatmap.shape[0:2]
|
||||
|
||||
left, right = min(x, radius), min(width - x, radius + 1)
|
||||
top, bottom = min(y, radius), min(height - y, radius + 1)
|
||||
|
||||
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
|
||||
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
|
||||
radius + right]
|
||||
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
|
||||
|
||||
|
||||
def gaussian2D(shape, sigma_x=1, sigma_y=1):
|
||||
m, n = [(ss - 1.) / 2. for ss in shape]
|
||||
y, x = np.ogrid[-m:m + 1, -n:n + 1]
|
||||
|
||||
h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
|
||||
sigma_y)))
|
||||
h[h < np.finfo(h.dtype).eps * h.max()] = 0
|
||||
return h
|
||||
|
||||
|
||||
def draw_umich_gaussian(heatmap, center, radius, k=1):
|
||||
"""
|
||||
draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
|
||||
"""
|
||||
diameter = 2 * radius + 1
|
||||
gaussian = gaussian2D(
|
||||
(diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
|
||||
|
||||
x, y = int(center[0]), int(center[1])
|
||||
|
||||
height, width = heatmap.shape[0:2]
|
||||
|
||||
left, right = min(x, radius), min(width - x, radius + 1)
|
||||
top, bottom = min(y, radius), min(height - y, radius + 1)
|
||||
|
||||
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
|
||||
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
|
||||
radius + right]
|
||||
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
|
||||
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
|
||||
return heatmap
|
||||
|
||||
|
||||
def get_border(border, size):
|
||||
i = 1
|
||||
while size - border // i <= border // i:
|
||||
i *= 2
|
||||
return border // i
|
||||
3797
rtdetr_paddle/ppdet/data/transform/operators.py
Normal file
3797
rtdetr_paddle/ppdet/data/transform/operators.py
Normal file
File diff suppressed because it is too large
Load Diff
71
rtdetr_paddle/ppdet/data/utils.py
Normal file
71
rtdetr_paddle/ppdet/data/utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from collections.abc import Sequence, Mapping
|
||||
except:
|
||||
from collections import Sequence, Mapping
|
||||
|
||||
|
||||
def default_collate_fn(batch):
|
||||
"""
|
||||
Default batch collating function for :code:`paddle.io.DataLoader`,
|
||||
get input data as a list of sample datas, each element in list
|
||||
if the data of a sample, and sample data should composed of list,
|
||||
dictionary, string, number, numpy array, this
|
||||
function will parse input data recursively and stack number,
|
||||
numpy array and paddle.Tensor datas as batch datas. e.g. for
|
||||
following input data:
|
||||
[{'image': np.array(shape=[3, 224, 224]), 'label': 1},
|
||||
{'image': np.array(shape=[3, 224, 224]), 'label': 3},
|
||||
{'image': np.array(shape=[3, 224, 224]), 'label': 4},
|
||||
{'image': np.array(shape=[3, 224, 224]), 'label': 5},]
|
||||
|
||||
|
||||
This default collate function zipped each number and numpy array
|
||||
field together and stack each field as the batch field as follows:
|
||||
{'image': np.array(shape=[4, 3, 224, 224]), 'label': np.array([1, 3, 4, 5])}
|
||||
Args:
|
||||
batch(list of sample data): batch should be a list of sample data.
|
||||
|
||||
Returns:
|
||||
Batched data: batched each number, numpy array and paddle.Tensor
|
||||
in input data.
|
||||
"""
|
||||
sample = batch[0]
|
||||
if isinstance(sample, np.ndarray):
|
||||
batch = np.stack(batch, axis=0)
|
||||
return batch
|
||||
elif isinstance(sample, numbers.Number):
|
||||
batch = np.array(batch)
|
||||
return batch
|
||||
elif isinstance(sample, (str, bytes)):
|
||||
return batch
|
||||
elif isinstance(sample, Mapping):
|
||||
return {
|
||||
key: default_collate_fn([d[key] for d in batch])
|
||||
for key in sample
|
||||
}
|
||||
elif isinstance(sample, Sequence):
|
||||
sample_fields_num = len(sample)
|
||||
if not all(len(sample) == sample_fields_num for sample in iter(batch)):
|
||||
raise RuntimeError(
|
||||
"fileds number not same among samples in a batch")
|
||||
return [default_collate_fn(fields) for fields in zip(*batch)]
|
||||
|
||||
raise TypeError("batch data con only contains: tensor, numpy.ndarray, "
|
||||
"dict, list, number, but got {}".format(type(sample)))
|
||||
Reference in New Issue
Block a user