From ffdf0a9268061c6b85907b2b25993fa33f7ec697 Mon Sep 17 00:00:00 2001 From: Ross Girshick Date: Sun, 9 Sep 2018 21:24:10 +0200 Subject: [PATCH] Improve assert error messages + minor cleanup --- .flake8 | 8 ++++ setup.py | 6 +-- yacs/config.py | 113 +++++++++++++++++++++++++++++++------------------ yacs/tests.py | 18 +++++++- 4 files changed, 98 insertions(+), 47 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..c286ad0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +# This is an example .flake8 config, used when developing *Black* itself. +# Keep in sync with setup.cfg which is used for source packages. + +[flake8] +ignore = E203, E266, E501, W503 +max-line-length = 80 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 diff --git a/setup.py b/setup.py index 691289a..5775f37 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="yacs", - version="0.1.2", + version="0.1.3", author="Ross Girshick", author_email="ross.girshick@gmail.com", description="Yet Another Configuration System", @@ -17,7 +17,5 @@ "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - install_requires=[ - "PyYAML", - ], + install_requires=["PyYAML"], ) diff --git a/yacs/config.py b/yacs/config.py index cf37677..5c484e0 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -40,7 +40,7 @@ _VALID_TYPES = {dict, tuple, list, str, int, float, bool} # py2 allow for str and unicode try: - _VALID_TYPES = _VALID_TYPES.union({unicode}) + _VALID_TYPES = _VALID_TYPES.union({unicode}) # noqa: F821 except Exception as _ignore: pass @@ -86,18 +86,20 @@ def __getattr__(self, name): def __setattr__(self, name, value): if self.is_frozen(): raise AttributeError( - 'Attempted to set "{}" to "{}", but CfgNode is immutable'.format( + "Attempted to set {} to {}, but CfgNode is immutable".format( name, value ) ) - assert ( - name not in self.__dict__ - ), "Invalid attempt to modify internal CfgNode state" - assert _valid_type( - value, allow_cfg_node=True - ), "Invalid type {} for key {}; valid types = {}".format( - type(value), name, _VALID_TYPES + _assert_with_logging( + name not in self.__dict__, + "Invalid attempt to modify internal CfgNode state: {}".format(name), + ) + _assert_with_logging( + _valid_type(value, allow_cfg_node=True), + "Invalid type {} for key {}; valid types = {}".format( + type(value), name, _VALID_TYPES + ), ) self[name] = value @@ -111,17 +113,22 @@ def merge_from_file(self, cfg_filename): """Load a yaml config file and merge it this CfgNode.""" with open(cfg_filename, "r") as f: cfg = CfgNode(load_cfg(f)) - _merge_a_into_b(cfg, self, self) + _merge_a_into_b(cfg, self, self, []) def merge_from_other_cfg(self, cfg_other): """Merge `cfg_other` into this CfgNode.""" - _merge_a_into_b(cfg_other, self, self) + _merge_a_into_b(cfg_other, self, self, []) def merge_from_list(self, cfg_list): """Merge config (keys, values) in a list (e.g., from command line) into this CfgNode. For example, `cfg_list = ['FOO.BAR', 0.5]`. """ - assert len(cfg_list) % 2 == 0 + _assert_with_logging( + len(cfg_list) % 2 == 0, + "Override list has odd length: {}; it must be a list of pairs".format( + cfg_list + ), + ) root = self for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): if root.key_is_deprecated(full_key): @@ -131,10 +138,12 @@ def merge_from_list(self, cfg_list): key_list = full_key.split(".") d = self for subkey in key_list[:-1]: - assert subkey in d, "Non-existent key: {}".format(full_key) + _assert_with_logging( + subkey in d, "Non-existent key: {}".format(full_key) + ) d = d[subkey] subkey = key_list[-1] - assert subkey in d, "Non-existent key: {}".format(full_key) + _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) value = _decode_cfg_value(v) value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) d[subkey] = value @@ -172,9 +181,10 @@ def register_deprecated_key(self, key): """Register key (e.g. `FOO.BAR`) a deprecated option. When merging deprecated keys a warning is generated and the key is ignored. """ - assert ( - key not in self.__dict__[CfgNode.DEPRECATED_KEYS] - ), "key '{}' is already registered as a deprecated key".format(key) + _assert_with_logging( + key not in self.__dict__[CfgNode.DEPRECATED_KEYS], + "key {} is already registered as a deprecated key".format(key), + ) self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) def register_renamed_key(self, old_name, new_name, message=None): @@ -182,9 +192,10 @@ def register_renamed_key(self, old_name, new_name, message=None): When merging a renamed key, an exception is thrown alerting to user to the fact that the key has been renamed. """ - assert ( - old_name not in self.__dict__[CfgNode.RENAMED_KEYS] - ), "key '{}' is already registered as a renamed cfg key".format(old_name) + _assert_with_logging( + old_name not in self.__dict__[CfgNode.RENAMED_KEYS], + "key {} is already registered as a renamed cfg key".format(old_name), + ) value = new_name if message: value = (new_name, message) @@ -217,9 +228,12 @@ def raise_key_rename_error(self, full_key): def load_cfg(cfg_file_or_string): """Load a cfg from a file or string.""" - assert isinstance( - cfg_file_or_string, _FILE_TYPES + (str,) - ), "Expected {} or {} got {}".format(_FILE_TYPES, str, type(cfg_file_or_string)) + _assert_with_logging( + isinstance(cfg_file_or_string, _FILE_TYPES + (str,)), + "Expected first argument to be of type {} or {}, but it was {}".format( + _FILE_TYPES, str, type(cfg_file_or_string) + ), + ) if isinstance(cfg_file_or_string, _FILE_TYPES): cfg_file_or_string = "".join(cfg_file_or_string.readlines()) cfg_as_dict = yaml.safe_load(cfg_file_or_string) @@ -229,52 +243,64 @@ def load_cfg(cfg_file_or_string): def _to_dict(cfg_node): """Recursively convert all CfgNode objects to dict objects.""" - def convert_to_dict(cfg_node): + def convert_to_dict(cfg_node, key_list): if not isinstance(cfg_node, CfgNode): - assert _valid_type(cfg_node) + _assert_with_logging( + _valid_type(cfg_node), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list), type(cfg_node), _VALID_TYPES + ), + ) return cfg_node else: cfg_dict = dict(cfg_node) for k, v in cfg_dict.items(): - cfg_dict[k] = convert_to_dict(v) + cfg_dict[k] = convert_to_dict(v, key_list + [k]) return cfg_dict - return convert_to_dict(cfg_node) + return convert_to_dict(cfg_node, []) def _to_cfg_node(cfg_dict): """Recursively convert all dict objects to CfgNode objects.""" - def convert_to_cfg_node(cfg_dict): + def convert_to_cfg_node(cfg_dict, key_list): if type(cfg_dict) is not dict: - assert _valid_type(cfg_dict) + _assert_with_logging( + _valid_type(cfg_dict), + "Key {} with value {} is not a valid type; valid types: {}".format( + ".".join(key_list), type(cfg_dict), _VALID_TYPES + ), + ) return cfg_dict else: cfg_node = CfgNode(cfg_dict) for k, v in cfg_node.items(): - cfg_node[k] = convert_to_cfg_node(v) + cfg_node[k] = convert_to_cfg_node(v, key_list + [k]) return cfg_node - return convert_to_cfg_node(cfg_dict) + return convert_to_cfg_node(cfg_dict, []) def _valid_type(value, allow_cfg_node=False): return (type(value) in _VALID_TYPES) or (allow_cfg_node and type(value) == CfgNode) -def _merge_a_into_b(a, b, root, stack=None): +def _merge_a_into_b(a, b, root, stack): """Merge config dictionary a into config dictionary b, clobbering the options in b whenever they are also specified in a. """ - assert isinstance(a, CfgNode), "`a` (cur type {}) must be an instance of {}".format( - type(a), CfgNode + _assert_with_logging( + isinstance(a, CfgNode), + "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), ) - assert isinstance(b, CfgNode), "`b` (cur type {}) must be an instance of {}".format( - type(b), CfgNode + _assert_with_logging( + isinstance(b, CfgNode), + "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), ) for k, v_ in a.items(): - full_key = ".".join(stack) + "." + k if stack is not None else k + full_key = ".".join(stack + [k]) # a must specify keys that are in b if k not in b: if root.key_is_deprecated(full_key): @@ -291,8 +317,7 @@ def _merge_a_into_b(a, b, root, stack=None): # Recursively merge dicts if isinstance(v, CfgNode): try: - stack_push = [k] if stack is None else stack + [k] - _merge_a_into_b(v, b[k], root, stack=stack_push) + _merge_a_into_b(v, b[k], root, stack + [k]) except BaseException: raise else: @@ -357,8 +382,8 @@ def conditional_cast(from_type, to_type): casts = [(tuple, list), (list, tuple)] # For py2: allow converting from str (bytes) to a unicode string try: - casts.append((str, unicode)) - except Exception as _ignore: + casts.append((str, unicode)) # noqa: F821 + except Exception: pass for (from_type, to_type) in casts: @@ -372,3 +397,9 @@ def conditional_cast(from_type, to_type): original_type, replacement_type, original, replacement, full_key ) ) + + +def _assert_with_logging(cond, msg): + if not cond: + logger.debug(msg) + assert cond, msg diff --git a/yacs/tests.py b/yacs/tests.py index 80ac0c0..77a00b8 100644 --- a/yacs/tests.py +++ b/yacs/tests.py @@ -6,7 +6,7 @@ from yacs.config import CfgNode as CN try: - _ignore = unicode + _ignore = unicode # noqa: F821 PY2 = True except Exception as _ignore: PY2 = False @@ -119,7 +119,7 @@ def test_merge_cfg_from_cfg(self): cfg2 = CN() cfg2.A_UNICODE_KEY = b"bar" cfg.merge_from_other_cfg(cfg2) - assert type(cfg.A_UNICODE_KEY) == unicode + assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821 assert cfg.A_UNICODE_KEY == u"bar" # Test: merge with invalid type @@ -169,6 +169,18 @@ def test_deprecated_key_from_list(self): with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa + def test_nonexistant_key_from_list(self): + cfg = get_cfg() + opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"] + with self.assertRaises(AssertionError): + cfg.merge_from_list(opts) + + def test_load_cfg_invalid_type(self): + # FOO.BAR.QUUX will have type None, which is not allowed + cfg_string = "FOO:\n BAR:\n QUUX:" + with self.assertRaises(AssertionError): + yacs.config.load_cfg(cfg_string) + def test_deprecated_key_from_file(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" @@ -214,4 +226,6 @@ def test_invalid_type(self): if __name__ == "__main__": logging.basicConfig() + yacs_logger = logging.getLogger("yacs.config") + yacs_logger.setLevel(logging.DEBUG) unittest.main()