# Copyright (c) Ye Liu. Licensed under the MIT License.
import os
import sys
from collections import OrderedDict
from copy import deepcopy
from importlib import import_module
from tempfile import TemporaryDirectory
import nncore
from .binder import bind_getter
from .data import is_seq_of
from .path import abs_path, cp, dir_name, is_file, join, pure_ext
class CfgNode(OrderedDict):
    """
    An extended :obj:`OrderedDict` class with several practical methods.

    This class is an extension of the built-in type :obj:`OrderedDict`. The
    interface is the same as a dict object and also allows access to config
    values as attributes. The input to the init method can be either a single
    dict or several named parameters. Dicts assigned to a node (also those
    nested inside lists and tuples) are converted into nodes of the same
    class.
    """

    @staticmethod
    def _set_freeze_state(obj, state):
        # Recursively set the ``_frozen`` flag on ``obj`` and on every node
        # reachable through nested nodes, lists and tuples.
        if isinstance(obj, CfgNode):
            # Bypass CfgNode.__setattr__, which would store a config key.
            super(CfgNode, obj).__setattr__('_frozen', state)
            for v in obj.values():
                CfgNode._set_freeze_state(v, state)
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                CfgNode._set_freeze_state(v, state)

    def __init__(self, *args, **kwargs):
        if len(args) > 1:
            raise TypeError('too many arguments')
        if len(args) == 1:
            if isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                raise TypeError("unsupported type '{}'".format(type(args[0])))
        super(CfgNode, self).__setattr__('_frozen', False)
        for key, value in kwargs.items():
            self[key] = value

    def __setitem__(self, key, value):
        self._check_freeze_state()
        super(CfgNode, self).__setitem__(key, self._parse_value(value))

    def __delitem__(self, key):
        # Deleting a key is a modification too, so frozen nodes reject it.
        self._check_freeze_state()
        super(CfgNode, self).__delitem__(key)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so instance
        # attributes such as ``_frozen`` never get here.
        try:
            return self[key]
        except KeyError:
            msg = "attribute '{}' is not found".format(key)
            raise AttributeError(msg) from None

    def __setattr__(self, key, value):
        # Names defined on the class (methods, properties) cannot be
        # shadowed by config values.
        if hasattr(self.__class__, key):
            raise AttributeError("attribute '{}' is read-only".format(key))
        self._check_freeze_state()
        self[key] = value

    def __setstate__(self, state):
        # Used when unpickling: restore every instance attribute (e.g.
        # ``_filename`` of subclasses), then apply the frozen flag to the
        # whole tree of nested nodes.
        state = dict(state or {})
        frozen = state.pop('_frozen', False)
        for name, attr in state.items():
            super(CfgNode, self).__setattr__(name, attr)
        self._set_freeze_state(self, frozen)

    def __deepcopy__(self, memo):

        def _copy(obj, memo):
            # Objects that deepcopy rejects with a TypeError are shared
            # instead of copied.
            try:
                return deepcopy(obj, memo)
            except TypeError:
                return obj

        other = self.__class__()
        memo[id(self)] = other
        # Carry over instance attributes such as ``_filename`` of
        # subclasses. The copy itself always starts unfrozen.
        for name, attr in vars(self).items():
            if name != '_frozen':
                super(CfgNode, other).__setattr__(name, _copy(attr, memo))
        for key, value in self.items():
            key = _copy(key, memo)
            if isinstance(value, (list, tuple)):
                value = type(value)(_copy(v, memo) for v in value)
            else:
                value = _copy(value, memo)
            other[key] = value
        return other

    def __eq__(self, other):
        # Unlike plain dict comparison, key order matters here. ``Config``
        # is itself a CfgNode (and thus a dict), so it needs no special case.
        if not isinstance(other, dict):
            return False
        if list(self.keys()) != list(other.keys()):
            return False
        for key in self.keys():
            if self[key] != other[key]:
                return False
        return True

    def __ne__(self, other):
        # dict/OrderedDict define their own ``!=``; without this override it
        # would bypass the order-sensitive ``__eq__`` above.
        return not self.__eq__(other)

    def __repr__(self):
        # Use dict's repr to avoid OrderedDict's ``Cls([(k, v), ...])`` form.
        return super(OrderedDict, self).__repr__()

    def _parse_value(self, value):
        # Convert dicts (recursively, also inside lists and tuples) into
        # nodes of this class. The dict is passed positionally rather than
        # unpacked as keyword arguments, so non-string keys work and keys
        # that coincide with constructor options of subclasses (such as
        # ``filename``/``freeze`` of ``Config``) are kept as config entries.
        if isinstance(value, dict):
            value = self.__class__(value)
        elif isinstance(value, (list, tuple)):
            value = type(value)(self._parse_value(v) for v in value)
        return value

    def _check_freeze_state(self):
        # Raises RuntimeError if this node has been frozen.
        if self._frozen:
            raise RuntimeError('can not modify a frozen {} object'.format(
                self.__class__))

    def freeze(self):
        """Recursively forbid modifications of this node and its children."""
        self._set_freeze_state(self, True)

    def unfreeze(self):
        """Recursively allow modifications of this node and its children."""
        self._set_freeze_state(self, False)

    def copy(self):
        """
        Return a deep copy of the node.

        Returns:
            :obj:`CfgNode`: An unfrozen deep copy of this node.
        """
        return deepcopy(self)

    def update(self, *args, **kwargs):
        """
        Recursively update the node from a dict and/or keyword arguments.

        Values that are dicts are merged into existing dict values with the
        same key; all other values replace the existing ones.

        Raises:
            TypeError: If more than one positional argument is given.
        """
        other = dict()
        if len(args) == 1:
            other.update(args[0])
        elif len(args) > 1:
            raise TypeError('too many arguments')
        other.update(kwargs)
        for key, value in other.items():
            if key not in self or not isinstance(
                    self[key], dict) or not isinstance(value, dict):
                # __setitem__ converts the value into nodes as needed.
                self[key] = value
            else:
                self[key].update(value)

    def merge_from(self, other):
        """
        Recursively merge another config into this node in place.

        Besides plain overriding, dict values may carry special keys that
        control the merge: ``_refine_``, ``_repeat_``, ``_delete_``,
        ``_insert_`` and ``_update_``. A value equal to the string
        ``'_delete_'`` removes the key. String keys of ``other`` that start
        with ``_`` are skipped.

        Args:
            other (dict): The config to merge from. It is copied first, so
                the special keys popped during merging are not removed from
                the caller's dict.
        """

        def _pairs(cfg):
            # Normalize an ``{'index': ..., 'value': ...}`` spec into two
            # parallel sequences of indexes and values.
            assert is_seq_of(cfg['index'], int) or isinstance(
                cfg['index'], int)
            if isinstance(cfg['index'], (list, tuple)):
                assert isinstance(cfg['value'], (list, tuple)) and len(
                    cfg['index']) == len(cfg['value'])
                return cfg['index'], cfg['value']
            return [cfg['index']], [cfg['value']]

        def _insert(ori, value, cfg):
            # ``cfg`` is a spec dict, an int position, or a non-int truthy
            # value (e.g. ``True``), which appends ``value``.
            assert isinstance(cfg, (dict, int))
            if isinstance(cfg, dict):
                for idx, val in zip(*_pairs(cfg)):
                    ori.insert(idx, val)
            elif type(cfg) is int:
                ori.insert(cfg, value)
            else:
                ori.append(value)
            return ori

        def _update(ori, value, cfg):
            # ``cfg`` is a spec dict or an int position. Dict elements are
            # merged recursively, anything else is replaced.
            assert isinstance(cfg, (dict, int))
            if isinstance(cfg, dict):
                for idx, val in zip(*_pairs(cfg)):
                    if isinstance(ori[idx], dict) and isinstance(val, dict):
                        ori[idx].merge_from(val)
                    else:
                        ori[idx] = val
            elif isinstance(ori[cfg], dict):
                ori[cfg].merge_from(value)
            else:
                ori[cfg] = value
            return ori

        assert isinstance(other, dict)
        other = self.__class__(other)
        for key, value in other.items():
            if isinstance(key, str) and key.startswith('_'):
                continue
            if isinstance(value, dict):
                refine = value.pop('_refine_', None)
                repeat = value.pop('_repeat_', None)
                delete = value.pop('_delete_', None)
                insert = value.pop('_insert_', None)
                update = value.pop('_update_', None)
                # Position 0 is falsy but valid, so ints always count as set.
                do_insert = type(insert) is int or insert
                do_update = type(update) is int or update
                if key not in self and refine:
                    # ``_refine_`` only adjusts keys that already exist.
                    continue
                if key in self and isinstance(repeat, int):
                    self[key] = [self[key] for _ in range(repeat)]
                    continue
                if key in self and isinstance(self[key], dict):
                    assert not do_update and not (delete and do_insert)
                    if do_insert:
                        self[key] = _insert([self[key]], value, insert)
                        continue
                    elif not delete:
                        self[key].merge_from(value)
                        continue
                if key in self and isinstance(self[key], (list, tuple)):
                    if do_insert or do_update:
                        # Tuples are edited as lists and converted back.
                        is_tuple = isinstance(self[key], tuple)
                        if is_tuple:
                            self[key] = list(self[key])
                        if do_insert:
                            self[key] = _insert(self[key], value, insert)
                        if do_update:
                            self[key] = _update(self[key], value, update)
                        if is_tuple:
                            self[key] = tuple(self[key])
                        continue
                assert not do_insert and not do_update
                # Replace the value; merging into an empty node strips the
                # special keys from nested dicts as well.
                tmp_node = self.__class__()
                tmp_node.merge_from(value)
                if tmp_node or not value:
                    self[key] = tmp_node
            elif value == '_delete_':
                if key in self:
                    del self[key]
            else:
                self[key] = value

    def to_dict(self, ordered=False):
        """
        Recursively convert the node into built-in containers.

        Args:
            ordered (bool, optional): Whether to build :obj:`OrderedDict`
                instead of :obj:`dict` objects, at every nesting level.
                Default: ``False``.

        Returns:
            dict | OrderedDict: The converted config. Nodes nested inside
            lists and tuples (at any depth) are converted as well.
        """

        def _convert(value):
            if isinstance(value, CfgNode):
                return value.to_dict(ordered=ordered)
            if isinstance(value, (list, tuple)):
                return type(value)(_convert(v) for v in value)
            return value

        base = OrderedDict() if ordered else dict()
        for key, value in self.items():
            base[key] = _convert(value)
        return base

    def to_json(self, indent=2):
        """Serialize the node into a JSON string via ``nncore.dumps``."""
        return nncore.dumps(self.to_dict(), format='json', indent=indent)
@bind_getter('filename')
class Config(CfgNode):
    """
    A facility for better :obj:`CfgNode` objects.

    This class inherits from :obj:`CfgNode` and it can be initialized from a
    config file. Users can use the static method :obj:`Config.from_file` to
    create a :obj:`Config` object.
    """

    @staticmethod
    def from_file(filename, freeze=False):
        """
        Build a :obj:`Config` object from a file.

        Args:
            filename (str): Path to the config file. Currently supported
                formats include ``py``, ``json``, and ``yaml/yml``.
            freeze (bool, optional): Whether to freeze the config after
                initialization. Default: ``False``.

        Returns:
            :obj:`Config`: The constructed config object.

        Raises:
            TypeError: If the file format is not supported.
        """

        def _load_py(path):
            # Import a copy of the file from a temporary directory under a
            # fresh module name and collect its non-dunder globals.
            with TemporaryDirectory() as tmp:
                # A name that is already in sys.modules would make
                # import_module return that cached module instead.
                mod_name = None
                while mod_name is None or mod_name in sys.modules:
                    mod_name = str(int.from_bytes(os.urandom(4), 'big'))
                cp(path, join(tmp, '{}.py'.format(mod_name)))
                sys.path.insert(0, tmp)
                try:
                    mod = import_module(mod_name)
                finally:
                    # Restore sys.path even if the config fails to import.
                    if tmp in sys.path:
                        sys.path.remove(tmp)
                return {
                    k: v
                    for k, v in mod.__dict__.items()
                    if not k.startswith('__') or not k.endswith('__')
                }

        filename = abs_path(filename)
        is_file(filename, raise_error=True)

        fmt = pure_ext(filename)
        if fmt == 'py':
            cfg = _load_py(filename)
        elif fmt in ('json', 'yml', 'yaml'):
            cfg = nncore.load(filename)
        else:
            raise TypeError("unsupported format: '{}'".format(fmt))

        if '_base_' in cfg:
            base = cfg.pop('_base_')
            if isinstance(base, str):
                base = [base]
            _cfg = CfgNode()
            for name in base:
                if name.endswith(('.py', '.json', '.yml', '.yaml')):
                    # Base config files are resolved relative to this file
                    # and merged in the listed order.
                    path = join(dir_name(filename), name)
                    _cfg.merge_from(Config.from_file(path))
                else:
                    # Other entries are imported as modules, with the current
                    # working directory made importable first. Presumably this
                    # is done for the import's side effects — TODO confirm.
                    curr_path = os.getcwd()
                    if curr_path not in sys.path:
                        sys.path.insert(0, curr_path)
                    import_module(name)
            # Values from this file take precedence over those of its bases.
            _cfg.merge_from(cfg)
            cfg = _cfg

        return Config(cfg, filename=filename, freeze=freeze)

    def __init__(self, *args, filename=None, freeze=False, **kwargs):
        # ``filename`` is stored as the private ``_filename`` attribute and
        # shown by ``__repr__`` and ``text``; ``freeze`` freezes the new node.
        super(Config, self).__init__(*args, **kwargs)
        super(CfgNode, self).__setattr__('_filename', filename)
        if freeze:
            self.freeze()

    def __repr__(self):
        return '{}({}frozen={}): {}'.format(
            self.__class__.__name__, '' if self._filename is None else
            "filename='{}', ".format(self._filename), self._frozen,
            super(Config, self).__repr__())

    @property
    def text(self):
        """
        str: A Python-like rendering of the config using ``dict(...)``
        syntax, preceded by a line holding the filename when one is set.
        """

        def _indent(a_str):
            # Indent every line except the first by four spaces.
            tokens = a_str.split('\n')
            if len(tokens) == 1:
                return a_str
            first = tokens.pop(0)
            tokens = [' ' * 4 + line for line in tokens]
            return '{}\n{}'.format(first, '\n'.join(tokens))

        def _basic(key, value, blank=True):
            # Render a scalar, optionally as ``key = value`` / ``key=value``.
            base_str = '{} = {}' if blank else '{}={}'
            if isinstance(value, dict):
                v_str = _dict(value)
            elif isinstance(value, str):
                v_str = "'{}'".format(value)
            else:
                v_str = str(value)
            a_str = v_str if key is None else base_str.format(key, v_str)
            return _indent(a_str)

        def _iterable(key, value, blank=True):
            # Render a list or tuple; elements are put on separate lines
            # when any element is itself a container.
            base_str, tokens = '{} = {}' if blank else '{}={}', []
            prefix = '\n' if len(value) > 1 else ''
            expand = any(isinstance(v, (dict, list, tuple)) for v in value)
            for v in value:
                if isinstance(v, dict):
                    if len(v) > 1:
                        a_str = _indent('\n' + _dict(v))
                    else:
                        a_str = _dict(v)
                    tokens.append(_indent(prefix + 'dict({})'.format(a_str)))
                elif isinstance(v, (list, tuple)):
                    tokens.append(_indent(prefix + _iterable(None, v)))
                else:
                    a_str = _basic(None, v)
                    tokens.append(_indent(prefix + a_str) if expand else a_str)
            if isinstance(value, list):
                left, right = '[', ']'
            elif len(value) == 1:
                # Without the trailing comma a one-element tuple would read
                # as a parenthesized expression.
                left, right = '(', ',)'
            else:
                left, right = '(', ')'
            sep = ',' if expand else ', '
            v_str = '{}{}{}'.format(left, sep.join(tokens), right)
            return v_str if key is None else base_str.format(key, v_str)

        def _dict(value, parent=False):
            # Render the items of a dict; top-level (``parent``) items use
            # ``key = value`` lines, nested ones ``key=value,`` arguments.
            base_str, tokens = '{} = dict({})' if parent else '{}=dict({})', []
            for i, (k, v) in enumerate(value.items()):
                end = '' if parent or i >= len(value) - 1 else ','
                if isinstance(v, dict):
                    if len(v) > 1:
                        a_str = base_str.format(str(k), '\n' + _dict(v))
                        a_str = _indent(a_str) + end
                    else:
                        a_str = base_str.format(str(k), _dict(v)) + end
                elif isinstance(v, (list, tuple)):
                    a_str = _iterable(k, v, blank=parent) + end
                else:
                    a_str = _basic(k, v, blank=parent) + end
                tokens.append(a_str)
            return '\n'.join(tokens)

        text = _dict(self.to_dict(), parent=True)
        if self._filename is not None:
            text = '{}\n'.format(self._filename) + text
        return text