From c8361140c936e3f33f373d2fe8658de0b5e00c7e Mon Sep 17 00:00:00 2001 From: Ross Girshick Date: Mon, 20 Aug 2018 11:34:11 -0700 Subject: [PATCH] Better conditional casting code --- .gitignore | 1 + setup.py | 2 +- tests/test_cfg.py => tests.py | 15 +++++++++ tests/__init__.py | 0 tox.ini | 2 +- yacs/config.py | 58 ++++++++++++++++++++++------------- 6 files changed, 55 insertions(+), 23 deletions(-) rename tests/test_cfg.py => tests.py (93%) delete mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore index 85d75ab..a239758 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ yacs.egg-info yacs/__pycache__ example/__pycache__ .tox +**/*.pyc diff --git a/setup.py b/setup.py index a13787d..1083a76 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,6 @@ author="Ross Girshick", author_email="ross.girshick@gmail.com", description="Yet Another Configuration System", - packages=["yacs", "tests"], + packages=["yacs"], long_description="A simple configuration system for research", ) diff --git a/tests/test_cfg.py b/tests.py similarity index 93% rename from tests/test_cfg.py rename to tests.py index f6bda8e..80ac0c0 100644 --- a/tests/test_cfg.py +++ b/tests.py @@ -5,6 +5,12 @@ import yacs.config from yacs.config import CfgNode as CN +try: + _ignore = unicode + PY2 = True +except Exception as _ignore: + PY2 = False + def get_cfg(): cfg = CN() @@ -107,6 +113,15 @@ def test_merge_cfg_from_cfg(self): assert type(cfg.TRAIN.SCALES) is tuple assert cfg.TRAIN.SCALES[0] == 1 + # Test str (bytes) <-> unicode conversion for py2 + if PY2: + cfg.A_UNICODE_KEY = u"foo" + cfg2 = CN() + cfg2.A_UNICODE_KEY = b"bar" + cfg.merge_from_other_cfg(cfg2) + assert type(cfg.A_UNICODE_KEY) == unicode + assert cfg.A_UNICODE_KEY == u"bar" + # Test: merge with invalid type cfg2 = CN() cfg2.TRAIN = CN() diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tox.ini b/tox.ini index 3a03670..cef851f 100644 --- a/tox.ini +++ b/tox.ini @@ -5,5 +5,5 @@ envlist = {py27,py36}-pyyaml{3,4} deps = pyyaml3: PyYAML>=3,<4 pyyaml4: PyYAML==4.2b4 -commands = python tests/test_cfg.py +commands = python tests.py diff --git a/yacs/config.py b/yacs/config.py index 2e06234..daaf2da 100644 --- a/yacs/config.py +++ b/yacs/config.py @@ -26,7 +26,7 @@ # py2 allow for str and unicode try: _VALID_TYPES = _VALID_TYPES.union({unicode}) -except Exception as e: +except Exception as _ignore: pass @@ -317,27 +317,43 @@ def _decode_cfg_value(v): return v -def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key): - """Checks that `value_a`, which is intended to replace `value_b` is of the - right type. The type is correct if it matches exactly or is one of a few +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few cases in which the type can be easily coerced. """ + original_type = type(original) + replacement_type = type(replacement) + # The types must match (with some exceptions) - type_b = type(value_b) - type_a = type(value_a) - if type_a is type_b: - return value_a - - # Exceptions: numpy arrays, strings, tuple<->list - if isinstance(value_b, str): - value_a = str(value_a) - elif isinstance(value_a, tuple) and isinstance(value_b, list): - value_a = list(value_a) - elif isinstance(value_a, list) and isinstance(value_b, tuple): - value_a = tuple(value_a) - else: - raise ValueError( - "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " - "key: {}".format(type_b, type_a, value_b, value_a, full_key) + if replacement_type == original_type: + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) + except Exception as _ignore: + pass + + for (from_type, to_type) in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " + "key: {}".format( + original_type, replacement_type, original, replacement, full_key ) - return value_a + )