diff --git a/ceci/stage.py b/ceci/stage.py index 39d8ddc..25a8a6b 100644 --- a/ceci/stage.py +++ b/ceci/stage.py @@ -490,6 +490,7 @@ def parse_command_line(cls, cmd=None): for conf, def_val in cls.config_options.items(): if isinstance(def_val, StageParameter): opt_type = def_val.dtype + def_val = def_val.default else: opt_type = def_val if isinstance(def_val, type) else type(def_val) if opt_type == bool: @@ -498,9 +499,12 @@ def parse_command_line(cls, cmd=None): f"--no-{conf}", dest=conf, action="store_const", const=False ) elif opt_type == list: - out_type = ( - def_val[0] if isinstance(def_val[0], type) else type(def_val[0]) - ) + if not def_val: + out_type = str + else: + out_type = ( + def_val[0] if isinstance(def_val[0], type) else type(def_val[0]) + ) if out_type is str: # pragma: no cover parser.add_argument( f"--{conf}", type=lambda string: string.split(",") diff --git a/tests/test_stage.py b/tests/test_stage.py index c0fe1ec..3bfdf6e 100644 --- a/tests/test_stage.py +++ b/tests/test_stage.py @@ -131,6 +131,8 @@ class TestStage(PipelineStage): config_options = dict( a=StageParameter(float, 5., msg="a float"), b=StageParameter(str, msg="a str"), + c=StageParameter(list, [1,2,3], msg="a list"), + d=StageParameter(list, [], msg="an empty list"), ) def run(self): @@ -141,7 +143,9 @@ def run(self): ) assert stage_1.config.a == 6. assert stage_1.config.b == 'puffins are not extinct?' - + assert 1 in stage_1.config.c + assert len(stage_1.config.d) == 0 + cmd = "TestStage", "--a", "6", "--b", "puffins are not extinct?", "--inp", "dummy" stage_1_cmd = TestStage(TestStage.parse_command_line(cmd)) assert stage_1_cmd.config.a == 6.