Skip to content

Commit

Permalink
Improve assert error messages + minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rbgirshick committed Sep 9, 2018
1 parent 6852b05 commit ffdf0a9
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 47 deletions.
8 changes: 8 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# This is an example .flake8 config, used when developing *Black* itself.
# Keep in sync with setup.cfg which is used for source packages.

[flake8]
ignore = E203, E266, E501, W503
max-line-length = 80
max-complexity = 18
select = B,C,E,F,W,T4,B9
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="yacs",
version="0.1.2",
version="0.1.3",
author="Ross Girshick",
author_email="[email protected]",
description="Yet Another Configuration System",
Expand All @@ -17,7 +17,5 @@
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
install_requires=[
"PyYAML",
],
install_requires=["PyYAML"],
)
113 changes: 72 additions & 41 deletions yacs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
_VALID_TYPES = {dict, tuple, list, str, int, float, bool}
# py2 allow for str and unicode
try:
_VALID_TYPES = _VALID_TYPES.union({unicode})
_VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821
except Exception as _ignore:
pass

Expand Down Expand Up @@ -86,18 +86,20 @@ def __getattr__(self, name):
def __setattr__(self, name, value):
if self.is_frozen():
raise AttributeError(
'Attempted to set "{}" to "{}", but CfgNode is immutable'.format(
"Attempted to set {} to {}, but CfgNode is immutable".format(
name, value
)
)

assert (
name not in self.__dict__
), "Invalid attempt to modify internal CfgNode state"
assert _valid_type(
value, allow_cfg_node=True
), "Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
_assert_with_logging(
name not in self.__dict__,
"Invalid attempt to modify internal CfgNode state: {}".format(name),
)
_assert_with_logging(
_valid_type(value, allow_cfg_node=True),
"Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
),
)

self[name] = value
Expand All @@ -111,17 +113,22 @@ def merge_from_file(self, cfg_filename):
"""Load a yaml config file and merge it this CfgNode."""
with open(cfg_filename, "r") as f:
cfg = CfgNode(load_cfg(f))
_merge_a_into_b(cfg, self, self)
_merge_a_into_b(cfg, self, self, [])

def merge_from_other_cfg(self, cfg_other):
"""Merge `cfg_other` into this CfgNode."""
_merge_a_into_b(cfg_other, self, self)
_merge_a_into_b(cfg_other, self, self, [])

def merge_from_list(self, cfg_list):
"""Merge config (keys, values) in a list (e.g., from command line) into
this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`.
"""
assert len(cfg_list) % 2 == 0
_assert_with_logging(
len(cfg_list) % 2 == 0,
"Override list has odd length: {}; it must be a list of pairs".format(
cfg_list
),
)
root = self
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
if root.key_is_deprecated(full_key):
Expand All @@ -131,10 +138,12 @@ def merge_from_list(self, cfg_list):
key_list = full_key.split(".")
d = self
for subkey in key_list[:-1]:
assert subkey in d, "Non-existent key: {}".format(full_key)
_assert_with_logging(
subkey in d, "Non-existent key: {}".format(full_key)
)
d = d[subkey]
subkey = key_list[-1]
assert subkey in d, "Non-existent key: {}".format(full_key)
_assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
value = _decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
d[subkey] = value
Expand Down Expand Up @@ -172,19 +181,21 @@ def register_deprecated_key(self, key):
"""Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated
keys a warning is generated and the key is ignored.
"""
assert (
key not in self.__dict__[CfgNode.DEPRECATED_KEYS]
), "key '{}' is already registered as a deprecated key".format(key)
_assert_with_logging(
key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
"key {} is already registered as a deprecated key".format(key),
)
self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)

def register_renamed_key(self, old_name, new_name, message=None):
"""Register a key as having been renamed from `old_name` to `new_name`.
When merging a renamed key, an exception is thrown alerting to user to
the fact that the key has been renamed.
"""
assert (
old_name not in self.__dict__[CfgNode.RENAMED_KEYS]
), "key '{}' is already registered as a renamed cfg key".format(old_name)
_assert_with_logging(
old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
"key {} is already registered as a renamed cfg key".format(old_name),
)
value = new_name
if message:
value = (new_name, message)
Expand Down Expand Up @@ -217,9 +228,12 @@ def raise_key_rename_error(self, full_key):

def load_cfg(cfg_file_or_string):
"""Load a cfg from a file or string."""
assert isinstance(
cfg_file_or_string, _FILE_TYPES + (str,)
), "Expected {} or {} got {}".format(_FILE_TYPES, str, type(cfg_file_or_string))
_assert_with_logging(
isinstance(cfg_file_or_string, _FILE_TYPES + (str,)),
"Expected first argument to be of type {} or {}, but it was {}".format(
_FILE_TYPES, str, type(cfg_file_or_string)
),
)
if isinstance(cfg_file_or_string, _FILE_TYPES):
cfg_file_or_string = "".join(cfg_file_or_string.readlines())
cfg_as_dict = yaml.safe_load(cfg_file_or_string)
Expand All @@ -229,52 +243,64 @@ def load_cfg(cfg_file_or_string):
def _to_dict(cfg_node):
"""Recursively convert all CfgNode objects to dict objects."""

def convert_to_dict(cfg_node):
def convert_to_dict(cfg_node, key_list):
if not isinstance(cfg_node, CfgNode):
assert _valid_type(cfg_node)
_assert_with_logging(
_valid_type(cfg_node),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES
),
)
return cfg_node
else:
cfg_dict = dict(cfg_node)
for k, v in cfg_dict.items():
cfg_dict[k] = convert_to_dict(v)
cfg_dict[k] = convert_to_dict(v, key_list + [k])
return cfg_dict

return convert_to_dict(cfg_node)
return convert_to_dict(cfg_node, [])


def _to_cfg_node(cfg_dict):
"""Recursively convert all dict objects to CfgNode objects."""

def convert_to_cfg_node(cfg_dict):
def convert_to_cfg_node(cfg_dict, key_list):
if type(cfg_dict) is not dict:
assert _valid_type(cfg_dict)
_assert_with_logging(
_valid_type(cfg_dict),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_dict), _VALID_TYPES
),
)
return cfg_dict
else:
cfg_node = CfgNode(cfg_dict)
for k, v in cfg_node.items():
cfg_node[k] = convert_to_cfg_node(v)
cfg_node[k] = convert_to_cfg_node(v, key_list + [k])
return cfg_node

return convert_to_cfg_node(cfg_dict)
return convert_to_cfg_node(cfg_dict, [])


def _valid_type(value, allow_cfg_node=False):
return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode)


def _merge_a_into_b(a, b, root, stack=None):
def _merge_a_into_b(a, b, root, stack):
"""Merge config dictionary a into config dictionary b, clobbering the
options in b whenever they are also specified in a.
"""
assert isinstance(a, CfgNode), "`a` (cur type {}) must be an instance of {}".format(
type(a), CfgNode
_assert_with_logging(
isinstance(a, CfgNode),
"`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
)
assert isinstance(b, CfgNode), "`b` (cur type {}) must be an instance of {}".format(
type(b), CfgNode
_assert_with_logging(
isinstance(b, CfgNode),
"`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
)

for k, v_ in a.items():
full_key = ".".join(stack) + "." + k if stack is not None else k
full_key = ".".join(stack + [k])
# a must specify keys that are in b
if k not in b:
if root.key_is_deprecated(full_key):
Expand All @@ -291,8 +317,7 @@ def _merge_a_into_b(a, b, root, stack=None):
# Recursively merge dicts
if isinstance(v, CfgNode):
try:
stack_push = [k] if stack is None else stack + [k]
_merge_a_into_b(v, b[k], root, stack=stack_push)
_merge_a_into_b(v, b[k], root, stack + [k])
except BaseException:
raise
else:
Expand Down Expand Up @@ -357,8 +382,8 @@ def conditional_cast(from_type, to_type):
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode))
except Exception as _ignore:
casts.append((str, unicode)) # noqa: F821
except Exception:
pass

for (from_type, to_type) in casts:
Expand All @@ -372,3 +397,9 @@ def conditional_cast(from_type, to_type):
original_type, replacement_type, original, replacement, full_key
)
)


def _assert_with_logging(cond, msg):
if not cond:
logger.debug(msg)
assert cond, msg
18 changes: 16 additions & 2 deletions yacs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from yacs.config import CfgNode as CN

try:
_ignore = unicode
_ignore = unicode # noqa: F821
PY2 = True
except Exception as _ignore:
PY2 = False
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_merge_cfg_from_cfg(self):
cfg2 = CN()
cfg2.A_UNICODE_KEY = b"bar"
cfg.merge_from_other_cfg(cfg2)
assert type(cfg.A_UNICODE_KEY) == unicode
assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821
assert cfg.A_UNICODE_KEY == u"bar"

# Test: merge with invalid type
Expand Down Expand Up @@ -169,6 +169,18 @@ def test_deprecated_key_from_list(self):
with self.assertRaises(AttributeError):
_ = cfg.MODEL.DILATION # noqa

def test_nonexistant_key_from_list(self):
cfg = get_cfg()
opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"]
with self.assertRaises(AssertionError):
cfg.merge_from_list(opts)

def test_load_cfg_invalid_type(self):
# FOO.BAR.QUUX will have type None, which is not allowed
cfg_string = "FOO:\n BAR:\n QUUX:"
with self.assertRaises(AssertionError):
yacs.config.load_cfg(cfg_string)

def test_deprecated_key_from_file(self):
# You should see logger messages like:
# "Deprecated config key (ignoring): MODEL.DILATION"
Expand Down Expand Up @@ -214,4 +226,6 @@ def test_invalid_type(self):

if __name__ == "__main__":
logging.basicConfig()
yacs_logger = logging.getLogger("yacs.config")
yacs_logger.setLevel(logging.DEBUG)
unittest.main()

0 comments on commit ffdf0a9

Please sign in to comment.