Skip to content

Commit

Permalink
Better conditional casting code
Browse files Browse the repository at this point in the history
  • Loading branch information
rbgirshick committed Aug 20, 2018
1 parent f5bb4f9 commit c836114
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ yacs.egg-info
yacs/__pycache__
example/__pycache__
.tox
**/*.pyc
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
author="Ross Girshick",
author_email="[email protected]",
description="Yet Another Configuration System",
packages=["yacs", "tests"],
packages=["yacs"],
long_description="A simple configuration system for research",
)
15 changes: 15 additions & 0 deletions tests/test_cfg.py → tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
import yacs.config
from yacs.config import CfgNode as CN

try:
_ignore = unicode
PY2 = True
except Exception as _ignore:
PY2 = False


def get_cfg():
cfg = CN()
Expand Down Expand Up @@ -107,6 +113,15 @@ def test_merge_cfg_from_cfg(self):
assert type(cfg.TRAIN.SCALES) is tuple
assert cfg.TRAIN.SCALES[0] == 1

# Test str (bytes) <-> unicode conversion for py2
if PY2:
cfg.A_UNICODE_KEY = u"foo"
cfg2 = CN()
cfg2.A_UNICODE_KEY = b"bar"
cfg.merge_from_other_cfg(cfg2)
assert type(cfg.A_UNICODE_KEY) == unicode
assert cfg.A_UNICODE_KEY == u"bar"

# Test: merge with invalid type
cfg2 = CN()
cfg2.TRAIN = CN()
Expand Down
Empty file removed tests/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ envlist = {py27,py36}-pyyaml{3,4}
deps =
pyyaml3: PyYAML>=3,<4
pyyaml4: PyYAML==4.2b4
commands = python tests/test_cfg.py
commands = python tests.py

58 changes: 37 additions & 21 deletions yacs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# py2 allow for str and unicode
try:
_VALID_TYPES = _VALID_TYPES.union({unicode})
except Exception as e:
except Exception as _ignore:
pass


Expand Down Expand Up @@ -317,27 +317,43 @@ def _decode_cfg_value(v):
return v


def _check_and_coerce_cfg_value_type(value_a, value_b, key, full_key):
"""Checks that `value_a`, which is intended to replace `value_b` is of the
right type. The type is correct if it matches exactly or is one of a few
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
"""Checks that `replacement`, which is intended to replace `original` is of
the right type. The type is correct if it matches exactly or is one of a few
cases in which the type can be easily coerced.
"""
original_type = type(original)
replacement_type = type(replacement)

# The types must match (with some exceptions)
type_b = type(value_b)
type_a = type(value_a)
if type_a is type_b:
return value_a

# Exceptions: numpy arrays, strings, tuple<->list
if isinstance(value_b, str):
value_a = str(value_a)
elif isinstance(value_a, tuple) and isinstance(value_b, list):
value_a = list(value_a)
elif isinstance(value_a, list) and isinstance(value_b, tuple):
value_a = tuple(value_a)
else:
raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
"key: {}".format(type_b, type_a, value_b, value_a, full_key)
if replacement_type == original_type:
return replacement

# Cast replacement from from_type to to_type if the replacement and original
# types match from_type and to_type
def conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None

# Conditionally casts
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
# For py2: allow converting from str (bytes) to a unicode string
try:
casts.append((str, unicode))
except Exception as _ignore:
pass

for (from_type, to_type) in casts:
converted, converted_value = conditional_cast(from_type, to_type)
if converted:
return converted_value

raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
"key: {}".format(
original_type, replacement_type, original, replacement, full_key
)
return value_a
)

0 comments on commit c836114

Please sign in to comment.