Skip to content

Commit

Permalink
Add functionality to allow addition of new value from yaml (#9)
Browse files Browse the repository at this point in the history
* Add additional parameter to CfgNode to allow new values

* Format config files for test properly

* Allow addition of new CfgNodes from yaml file

* Format config properly

* Simplify logic for merging configurations
  • Loading branch information
Rizhiy authored and rbgirshick committed Dec 10, 2018
1 parent a4bba08 commit 6c41d01
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
6 changes: 6 additions & 0 deletions example/config_new_allowed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
KWARGS:
a: 1 # Test adding of basic value
B:
c: 2 # Test adding of another node
D:
e: '3' # Test adding of nested nodes
3 changes: 3 additions & 0 deletions example/config_new_allowed_bad.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
KWARGS:
Y:
f: 4 # While `KWARGS` allows new nodes, `KWARGS.Y` doesn't, so this should raise an Exception
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="yacs",
version="0.1.4",
version="0.1.5",
author="Ross Girshick",
author_email="[email protected]",
description="Yet Another Configuration System",
Expand Down
42 changes: 25 additions & 17 deletions yacs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import yaml


# Flag for py2 and py3 compatibility to use when separate code paths are necessary
# When _PY2 is False, we assume Python 3 is in use
_PY2 = False
Expand Down Expand Up @@ -70,8 +69,9 @@ class CfgNode(dict):
IMMUTABLE = "__immutable__"
DEPRECATED_KEYS = "__deprecated_keys__"
RENAMED_KEYS = "__renamed_keys__"
NEW_ALLOWED = "__new_allowed__"

def __init__(self, init_dict=None, key_list=None):
def __init__(self, init_dict=None, key_list=None, new_allowed=False):
# Recursively convert nested dictionaries in init_dict into CfgNodes
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
Expand Down Expand Up @@ -108,6 +108,9 @@ def __init__(self, init_dict=None, key_list=None):
# ),
}

# Allow new attributes after initialisation
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed

def __getattr__(self, name):
if name in self:
return self[name]
Expand Down Expand Up @@ -280,6 +283,9 @@ def raise_key_rename_error(self, full_key):
)
)

def is_new_allowed(self):
return self.__dict__[CfgNode.NEW_ALLOWED]


def load_cfg(cfg_file_obj_or_str):
"""Load a cfg. Supports loading from:
Expand Down Expand Up @@ -382,28 +388,30 @@ def _merge_a_into_b(a, b, root, key_list):

for k, v_ in a.items():
full_key = ".".join(key_list + [k])
# a must specify keys that are in b
if k not in b:

v = copy.deepcopy(v_)
v = _decode_cfg_value(v)

if k in b:
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
# Recursively merge dicts
if isinstance(v, CfgNode):
try:
_merge_a_into_b(v, b[k], root, key_list + [k])
except BaseException:
raise
else:
b[k] = v
elif b.is_new_allowed():
b[k] = v
else:
if root.key_is_deprecated(full_key):
continue
elif root.key_is_renamed(full_key):
root.raise_key_rename_error(full_key)
else:
raise KeyError("Non-existent config key: {}".format(full_key))

v = copy.deepcopy(v_)
v = _decode_cfg_value(v)
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)

# Recursively merge dicts
if isinstance(v, CfgNode):
try:
_merge_a_into_b(v, b[k], root, key_list + [k])
except BaseException:
raise
else:
b[k] = v


def _decode_cfg_value(v):
"""Decodes a raw config value (e.g., from a yaml config files or command
Expand Down
21 changes: 21 additions & 0 deletions yacs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def get_cfg():
message="Please update your config fil config file.",
)

cfg.KWARGS = CN(new_allowed=True)
cfg.KWARGS.z = 0
cfg.KWARGS.Y = CN()
cfg.KWARGS.Y.X = 1

return cfg


Expand Down Expand Up @@ -254,6 +259,10 @@ def test_invalid_type(self):

def test__str__(self):
expected_str = """
KWARGS:
Y:
X: 1
z: 0
MODEL:
TYPE: a_foo_model
NUM_GPUS: 8
Expand All @@ -273,6 +282,18 @@ def test__str__(self):
cfg = get_cfg()
assert str(cfg) == expected_str

def test_new_allowed(self):
cfg = get_cfg()
cfg.merge_from_file("example/config_new_allowed.yaml")
assert cfg.KWARGS.a == 1
assert cfg.KWARGS.B.c == 2
assert cfg.KWARGS.B.D.e == '3'

def test_new_allowed_bad(self):
cfg = get_cfg()
with self.assertRaises(KeyError):
cfg.merge_from_file("example/config_new_allowed_bad.yaml")


if __name__ == "__main__":
logging.basicConfig()
Expand Down

0 comments on commit 6c41d01

Please sign in to comment.