Skip to content

Commit

Permalink
Merge pull request #84 from LSSTDESC/issue/83/stage_param
Browse files Browse the repository at this point in the history
Fix parse_command_line to deal with StageParameters in configs
  • Loading branch information
eacharles authored Oct 5, 2022
2 parents dc92ef1 + 011891a commit 3a4a149
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
8 changes: 5 additions & 3 deletions ceci/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import abstractmethod
from . import errors
from .monitor import MemoryMonitor
from .config import StageConfig, cast_to_streamable
from .config import StageParameter, StageConfig, cast_to_streamable

SERIAL = "serial"
MPI_PARALLEL = "mpi"
Expand Down Expand Up @@ -488,8 +488,10 @@ def parse_command_line(cls, cmd=None):
parser = argparse.ArgumentParser(description=f"Run pipeline stage {cls.name}")
parser.add_argument("stage_name")
for conf, def_val in cls.config_options.items():
opt_type = def_val if isinstance(def_val, type) else type(def_val)

if isinstance(def_val, StageParameter):
opt_type = def_val.dtype
else:
opt_type = def_val if isinstance(def_val, type) else type(def_val)
if opt_type == bool:
parser.add_argument(f"--{conf}", action="store_const", const=True)
parser.add_argument(
Expand Down
18 changes: 15 additions & 3 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,25 @@ class TestStage(PipelineStage):
name = "test_stage_param"
inputs = [("inp1", HDFFile)]
outputs = []
config_options = dict(a=StageParameter(float, 5., msg="a float"))
config_options = dict(
a=StageParameter(float, 5., msg="a float"),
b=StageParameter(str, msg="a str"),
)

def run(self):
pass

stage_1 = TestStage.make_stage(a=6., inp1='dummy')
stage_1 = TestStage.make_stage(
a=6., b='puffins are not extinct?', inp1='dummy',
)
assert stage_1.config.a == 6.
assert stage_1.config.b == 'puffins are not extinct?'

cmd = "TestStage", "--a", "6", "--b", "puffins are not extinct?", "--inp", "dummy"
stage_1_cmd = TestStage(TestStage.parse_command_line(cmd))
assert stage_1_cmd.config.a == 6.
assert stage_1_cmd.config.b == 'puffins are not extinct?'


# This one should not work
class TestStage_2(PipelineStage):
Expand Down Expand Up @@ -198,7 +210,7 @@ def run(self):





def test_incomplete():
class Alpha(PipelineStage):
Expand Down

0 comments on commit 3a4a149

Please sign in to comment.