Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 2.0 bug with child class constructors #112

Merged
merged 4 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions ceci/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,7 @@ def build_stage_object(self, args):
self.stage_class = PipelineStage.get_stage(
self.class_name, self.module_name
)
# EAC. Ideally we would just pass the aliases into the c'tor of self.stage_class(),
# but that would requiring changing the signature of every sub-class, so we do this
# instead. At some point we might want to migrate to doing it the better way
try:
self.stage_obj = self.stage_class(args, aliases=self.aliases)
except TypeError:
self.stage_obj = self.stage_class(args)
self.stage_obj._aliases.update(**self.aliases)
self.stage_obj._io_checked = False
self.stage_obj.check_io()
self.stage_obj = self.stage_class(args, aliases=self.aliases)
return self.stage_obj

def generate_full_command(self, inputs, outputs, config):
Expand Down Expand Up @@ -674,9 +665,11 @@ def build_stage(self, stage_class, **kwargs):
`Pipeline.pipeline_files` data member, so that they are available to later stages
"""
kwcopy = kwargs.copy()
aliases = kwcopy.pop("aliases", {})
comm = kwcopy.pop("comm", None)
kwcopy.update(**self.pipeline_files)
aliases = kwcopy.pop("aliases", None)
stage = stage_class(kwcopy, aliases=aliases)

stage = stage_class(kwcopy, comm=comm, aliases=aliases)
return self.add_stage(stage)

def remove_stage(self, name):
Expand Down
12 changes: 4 additions & 8 deletions ceci/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import cProfile
import pdb
import datetime
import warnings

from abc import abstractmethod
from . import errors
Expand Down Expand Up @@ -102,6 +103,8 @@ def __init__(self, args, comm=None, aliases=None):
if comm is not None:
self.setup_mpi(comm)

self.check_io()


@classmethod
def make_stage(cls, **kwargs):
Expand All @@ -117,13 +120,7 @@ def make_stage(cls, **kwargs):
for output_ in cls.outputs: # pylint: disable=no-member
outtag = output_[0]
aliases[outtag] = f"{outtag}_{name}"
# EAC. Ideally we would just pass the aliases into the construction call
# but that would requiring changing the signature of every sub-class, so we do this
# instead. At some point we might want to migrate to doing it the better way
stage = cls(kwcopy, comm=comm)
stage._aliases.update(**aliases)
stage._io_checked = False
stage.check_io()
stage = cls(kwcopy, comm=comm, aliases=aliases)
return stage

def get_aliases(self):
Expand Down Expand Up @@ -181,7 +178,6 @@ def load_configs(self, args):
error_class = type(error)
msg = str(error)
raise error_class(f"Error configuring {self.instance_name}: {msg}")
self.check_io(args)

def check_io(self, args=None):
"""
Expand Down
9 changes: 7 additions & 2 deletions ceci/update_pipelines_for_ceci_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,20 @@ def update_pipeline_file_group(pipeline_files):
def scan_directory_and_update(base_dir):
groups = collections.defaultdict(list)
yaml = ruamel.yaml.YAML()
for dirpath, subdirs, filenames in os.walk(base_dir):
yaml.allow_duplicate_keys = True
for dirpath, _, filenames in os.walk(base_dir):
# just process yaml files
for filename in filenames:
if not (filename.endswith(".yaml") or filename.endswith(".yml")):
continue
filepath = os.path.join(dirpath, filename)
with open(filepath) as f:
yaml_str = f.read()
info = yaml.load(yaml_str)
try:
info = yaml.load(yaml_str)
except:
print("# Could not read yaml file:", filepath)
continue

if is_pipeline_file(info):
config = info["config"]
Expand Down
1 change: 0 additions & 1 deletion tests/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,3 @@ def test_inter_pipe():

if __name__ == "__main__":
test_config()
test_interactive()
2 changes: 0 additions & 2 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,6 @@ class LimaSerial(LimaParallel):
assert LimaSerial.parse_command_line(["LimaSerial", "--mpi"]).mpi


# could add more tests here for constructor, but the regression tests here and in TXPipe are
# pretty thorough.

if __name__ == "__main__":
test_construct()
Expand Down
Loading