first commit
This commit is contained in:
13
rtdetr_paddle/ppdet/core/config/__init__.py
Normal file
13
rtdetr_paddle/ppdet/core/config/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
248
rtdetr_paddle/ppdet/core/config/schema.py
Normal file
248
rtdetr_paddle/ppdet/core/config/schema.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import division
|
||||
|
||||
import inspect
|
||||
import importlib
|
||||
import re
|
||||
|
||||
try:
|
||||
from docstring_parser import parse as doc_parse
|
||||
except Exception:
|
||||
|
||||
def doc_parse(*args):
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
from typeguard import check_type
|
||||
except Exception:
|
||||
|
||||
def check_type(*args):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['SchemaValue', 'SchemaDict', 'SharedConfig', 'extract_schema']
|
||||
|
||||
|
||||
class SchemaValue(object):
|
||||
def __init__(self, name, doc='', type=None):
|
||||
super(SchemaValue, self).__init__()
|
||||
self.name = name
|
||||
self.doc = doc
|
||||
self.type = type
|
||||
|
||||
def set_default(self, value):
|
||||
self.default = value
|
||||
|
||||
def has_default(self):
|
||||
return hasattr(self, 'default')
|
||||
|
||||
|
||||
class SchemaDict(dict):
|
||||
def __init__(self, **kwargs):
|
||||
super(SchemaDict, self).__init__()
|
||||
self.schema = {}
|
||||
self.strict = False
|
||||
self.doc = ""
|
||||
self.update(kwargs)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# XXX also update regular dict to SchemaDict??
|
||||
if isinstance(value, dict) and key in self and isinstance(self[key],
|
||||
SchemaDict):
|
||||
self[key].update(value)
|
||||
else:
|
||||
super(SchemaDict, self).__setitem__(key, value)
|
||||
|
||||
def __missing__(self, key):
|
||||
if self.has_default(key):
|
||||
return self.schema[key].default
|
||||
elif key in self.schema:
|
||||
return self.schema[key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def copy(self):
|
||||
newone = SchemaDict()
|
||||
newone.__dict__.update(self.__dict__)
|
||||
newone.update(self)
|
||||
return newone
|
||||
|
||||
def set_schema(self, key, value):
|
||||
assert isinstance(value, SchemaValue)
|
||||
self.schema[key] = value
|
||||
|
||||
def set_strict(self, strict):
|
||||
self.strict = strict
|
||||
|
||||
def has_default(self, key):
|
||||
return key in self.schema and self.schema[key].has_default()
|
||||
|
||||
def is_default(self, key):
|
||||
if not self.has_default(key):
|
||||
return False
|
||||
if hasattr(self[key], '__dict__'):
|
||||
return True
|
||||
else:
|
||||
return key not in self or self[key] == self.schema[key].default
|
||||
|
||||
def find_default_keys(self):
|
||||
return [
|
||||
k for k in list(self.keys()) + list(self.schema.keys())
|
||||
if self.is_default(k)
|
||||
]
|
||||
|
||||
def mandatory(self):
|
||||
return any([k for k in self.schema.keys() if not self.has_default(k)])
|
||||
|
||||
def find_missing_keys(self):
|
||||
missing = [
|
||||
k for k in self.schema.keys()
|
||||
if k not in self and not self.has_default(k)
|
||||
]
|
||||
placeholders = [k for k in self if self[k] in ('<missing>', '<value>')]
|
||||
return missing + placeholders
|
||||
|
||||
def find_extra_keys(self):
|
||||
return list(set(self.keys()) - set(self.schema.keys()))
|
||||
|
||||
def find_mismatch_keys(self):
|
||||
mismatch_keys = []
|
||||
for arg in self.schema.values():
|
||||
if arg.type is not None:
|
||||
try:
|
||||
check_type("{}.{}".format(self.name, arg.name),
|
||||
self[arg.name], arg.type)
|
||||
except Exception:
|
||||
mismatch_keys.append(arg.name)
|
||||
return mismatch_keys
|
||||
|
||||
def validate(self):
|
||||
missing_keys = self.find_missing_keys()
|
||||
if missing_keys:
|
||||
raise ValueError("Missing param for class<{}>: {}".format(
|
||||
self.name, ", ".join(missing_keys)))
|
||||
extra_keys = self.find_extra_keys()
|
||||
if extra_keys and self.strict:
|
||||
raise ValueError("Extraneous param for class<{}>: {}".format(
|
||||
self.name, ", ".join(extra_keys)))
|
||||
mismatch_keys = self.find_mismatch_keys()
|
||||
if mismatch_keys:
|
||||
raise TypeError("Wrong param type for class<{}>: {}".format(
|
||||
self.name, ", ".join(mismatch_keys)))
|
||||
|
||||
|
||||
class SharedConfig(object):
|
||||
"""
|
||||
Representation class for `__shared__` annotations, which work as follows:
|
||||
|
||||
- if `key` is set for the module in config file, its value will take
|
||||
precedence
|
||||
- if `key` is not set for the module but present in the config file, its
|
||||
value will be used
|
||||
- otherwise, use the provided `default_value` as fallback
|
||||
|
||||
Args:
|
||||
key: config[key] will be injected
|
||||
default_value: fallback value
|
||||
"""
|
||||
|
||||
def __init__(self, key, default_value=None):
|
||||
super(SharedConfig, self).__init__()
|
||||
self.key = key
|
||||
self.default_value = default_value
|
||||
|
||||
|
||||
def extract_schema(cls):
|
||||
"""
|
||||
Extract schema from a given class
|
||||
|
||||
Args:
|
||||
cls (type): Class from which to extract.
|
||||
|
||||
Returns:
|
||||
schema (SchemaDict): Extracted schema.
|
||||
"""
|
||||
ctor = cls.__init__
|
||||
# python 2 compatibility
|
||||
if hasattr(inspect, 'getfullargspec'):
|
||||
argspec = inspect.getfullargspec(ctor)
|
||||
annotations = argspec.annotations
|
||||
has_kwargs = argspec.varkw is not None
|
||||
else:
|
||||
argspec = inspect.getfullargspec(ctor)
|
||||
# python 2 type hinting workaround, see pep-3107
|
||||
# however, since `typeguard` does not support python 2, type checking
|
||||
# is still python 3 only for now
|
||||
annotations = getattr(ctor, '__annotations__', {})
|
||||
has_kwargs = argspec.varkw is not None
|
||||
|
||||
names = [arg for arg in argspec.args if arg != 'self']
|
||||
defaults = argspec.defaults
|
||||
num_defaults = argspec.defaults is not None and len(argspec.defaults) or 0
|
||||
num_required = len(names) - num_defaults
|
||||
|
||||
docs = cls.__doc__
|
||||
if docs is None and getattr(cls, '__category__', None) == 'op':
|
||||
docs = cls.__call__.__doc__
|
||||
try:
|
||||
docstring = doc_parse(docs)
|
||||
except Exception:
|
||||
docstring = None
|
||||
|
||||
if docstring is None:
|
||||
comments = {}
|
||||
else:
|
||||
comments = {}
|
||||
for p in docstring.params:
|
||||
match_obj = re.match('^([a-zA-Z_]+[a-zA-Z_0-9]*).*', p.arg_name)
|
||||
if match_obj is not None:
|
||||
comments[match_obj.group(1)] = p.description
|
||||
|
||||
schema = SchemaDict()
|
||||
schema.name = cls.__name__
|
||||
schema.doc = ""
|
||||
if docs is not None:
|
||||
start_pos = docs[0] == '\n' and 1 or 0
|
||||
schema.doc = docs[start_pos:].split("\n")[0].strip()
|
||||
# XXX handle paddle's weird doc convention
|
||||
if '**' == schema.doc[:2] and '**' == schema.doc[-2:]:
|
||||
schema.doc = schema.doc[2:-2].strip()
|
||||
schema.category = hasattr(cls, '__category__') and getattr(
|
||||
cls, '__category__') or 'module'
|
||||
schema.strict = not has_kwargs
|
||||
schema.pymodule = importlib.import_module(cls.__module__)
|
||||
schema.inject = getattr(cls, '__inject__', [])
|
||||
schema.shared = getattr(cls, '__shared__', [])
|
||||
for idx, name in enumerate(names):
|
||||
comment = name in comments and comments[name] or name
|
||||
if name in schema.inject:
|
||||
type_ = None
|
||||
else:
|
||||
type_ = name in annotations and annotations[name] or None
|
||||
value_schema = SchemaValue(name, comment, type_)
|
||||
if name in schema.shared:
|
||||
assert idx >= num_required, "shared config must have default value"
|
||||
default = defaults[idx - num_required]
|
||||
value_schema.set_default(SharedConfig(name, default))
|
||||
elif idx >= num_required:
|
||||
default = defaults[idx - num_required]
|
||||
value_schema.set_default(default)
|
||||
schema.set_schema(name, value_schema)
|
||||
|
||||
return schema
|
||||
118
rtdetr_paddle/ppdet/core/config/yaml_helpers.py
Normal file
118
rtdetr_paddle/ppdet/core/config/yaml_helpers.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
import yaml
|
||||
from .schema import SharedConfig
|
||||
|
||||
__all__ = ['serializable', 'Callable']
|
||||
|
||||
|
||||
def represent_dictionary_order(self, dict_data):
|
||||
return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
|
||||
|
||||
|
||||
def setup_orderdict():
|
||||
from collections import OrderedDict
|
||||
yaml.add_representer(OrderedDict, represent_dictionary_order)
|
||||
|
||||
|
||||
def _make_python_constructor(cls):
|
||||
def python_constructor(loader, node):
|
||||
if isinstance(node, yaml.SequenceNode):
|
||||
args = loader.construct_sequence(node, deep=True)
|
||||
return cls(*args)
|
||||
else:
|
||||
kwargs = loader.construct_mapping(node, deep=True)
|
||||
try:
|
||||
return cls(**kwargs)
|
||||
except Exception as ex:
|
||||
print("Error when construct {} instance from yaml config".
|
||||
format(cls.__name__))
|
||||
raise ex
|
||||
|
||||
return python_constructor
|
||||
|
||||
|
||||
def _make_python_representer(cls):
|
||||
# python 2 compatibility
|
||||
if hasattr(inspect, 'getfullargspec'):
|
||||
argspec = inspect.getfullargspec(cls)
|
||||
else:
|
||||
argspec = inspect.getfullargspec(cls.__init__)
|
||||
argnames = [arg for arg in argspec.args if arg != 'self']
|
||||
|
||||
def python_representer(dumper, obj):
|
||||
if argnames:
|
||||
data = {name: getattr(obj, name) for name in argnames}
|
||||
else:
|
||||
data = obj.__dict__
|
||||
if '_id' in data:
|
||||
del data['_id']
|
||||
return dumper.represent_mapping(u'!{}'.format(cls.__name__), data)
|
||||
|
||||
return python_representer
|
||||
|
||||
|
||||
def serializable(cls):
|
||||
"""
|
||||
Add loader and dumper for given class, which must be
|
||||
"trivially serializable"
|
||||
|
||||
Args:
|
||||
cls: class to be serialized
|
||||
|
||||
Returns: cls
|
||||
"""
|
||||
yaml.add_constructor(u'!{}'.format(cls.__name__),
|
||||
_make_python_constructor(cls))
|
||||
yaml.add_representer(cls, _make_python_representer(cls))
|
||||
return cls
|
||||
|
||||
|
||||
yaml.add_representer(SharedConfig,
|
||||
lambda d, o: d.represent_data(o.default_value))
|
||||
|
||||
|
||||
@serializable
|
||||
class Callable(object):
|
||||
"""
|
||||
Helper to be used in Yaml for creating arbitrary class objects
|
||||
|
||||
Args:
|
||||
full_type (str): the full module path to target function
|
||||
"""
|
||||
|
||||
def __init__(self, full_type, args=[], kwargs={}):
|
||||
super(Callable, self).__init__()
|
||||
self.full_type = full_type
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(self):
|
||||
if '.' in self.full_type:
|
||||
idx = self.full_type.rfind('.')
|
||||
module = importlib.import_module(self.full_type[:idx])
|
||||
func_name = self.full_type[idx + 1:]
|
||||
else:
|
||||
try:
|
||||
module = importlib.import_module('builtins')
|
||||
except Exception:
|
||||
module = importlib.import_module('__builtin__')
|
||||
func_name = self.full_type
|
||||
|
||||
func = getattr(module, func_name)
|
||||
return func(*self.args, **self.kwargs)
|
||||
Reference in New Issue
Block a user