Skip to content

Commit

Permalink
Merge pull request #90 from LSSTDESC/get_output-aliasing
Browse files Browse the repository at this point in the history
Have get_input and get_output use aliases
  • Loading branch information
joezuntz authored May 11, 2023
2 parents d0b63f0 + 3d4d930 commit 90e11f5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
24 changes: 18 additions & 6 deletions ceci/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,15 +914,25 @@ def data_ranges_by_rank(self, n_rows, chunk_rows, parallel=True):
##################################################

def get_input(self, tag):
"""Return the path of an input file with the given tag"""
"""
Return the path of an input file with the given tag,
which can be aliased.
"""
tag = self.get_aliased_tag(tag)
return self._inputs[tag]



def get_output(self, tag, final_name=False):
"""Return the path of an output file with the given tag
"""
Return the path of an output file with the given tag,
which can be aliased already.
If final_name is False then use a temporary name - file will
be moved to its final name at the end
"""

tag = self.get_aliased_tag(tag)
path = self._outputs[tag]

# If not the final version, add a tag at the start of the filename
Expand All @@ -943,8 +953,7 @@ def open_input(self, tag, wrapper=False, **kwargs):
a more specific object - see the types.py file for more info.
"""
aliased_tag = self.get_aliased_tag(tag)
path = self.get_input(aliased_tag)
path = self.get_input(tag)
input_class = self.get_input_type(tag)
obj = input_class(path, "r", **kwargs)

Expand Down Expand Up @@ -984,8 +993,7 @@ def open_output(
Extra args are passed on to the file's class constructor.
"""
aliased_tag = self.get_aliased_tag(tag)
path = self.get_output(aliased_tag, final_name=final_name)
path = self.get_output(tag, final_name=final_name)
output_class = self.get_output_type(tag)

# HDF files can be opened for parallel writing
Expand Down Expand Up @@ -1055,14 +1063,18 @@ def input_tags(cls):

def get_input_type(self, tag):
"""Return the file type class of an input file with the given tag."""
tag = self.get_aliased_tag(tag)
for t, dt in self.inputs_():
t = self.get_aliased_tag(t)
if t == tag:
return dt
raise ValueError(f"Tag {tag} is not a known input") # pragma: no cover

def get_output_type(self, tag):
"""Return the file type class of an output file with the given tag."""
tag = self.get_aliased_tag(tag)
for t, dt in self.outputs_():
t = self.get_aliased_tag(t)
if t == tag:
return dt
raise ValueError(f"Tag {tag} is not a known output") # pragma: no cover
Expand Down
34 changes: 31 additions & 3 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,19 @@ class India(PipelineStage):

print(ii.get_aliases())

# This currently works
# These should work with or without the alias
assert os.path.exists(ii.get_input("my_alias"))
assert os.path.exists(ii.get_input("my_input"))

# This works now
f = ii.open_input("my_input")
print(f.keys())
f.close()

f = ii.open_input("my_alias")
print(f.keys())
f.close()


def test_open_output():
class Juliett(PipelineStage):
Expand All @@ -453,16 +458,39 @@ class Juliett(PipelineStage):
print(f.keys())
f.close()

# Testing with an alias - config.yml defines an alias for my_input, my_alias
jj = Juliett.make_stage(name="JuliettCopy", aliases=dict(my_output='my_alias'))
# Testing with an alias
jj = Juliett.make_stage(aliases=dict(my_output='my_alias'))

print(jj.get_aliases())

assert jj.get_output("my_output") == jj.get_output("my_alias")

# This works now
f = jj.open_output("my_output")
print(f.keys())
f.close()

f = jj.open_output("my_alias")
print(f.keys())
f.close()

# Testing with a new name
jj = Juliett.make_stage(name="JuliettCopy")

print(jj.get_aliases())

assert jj.get_output("my_output") == jj.get_output("my_output_JuliettCopy")

# Check we can open using the original name
f = jj.open_output("my_output")
print(f.keys())
f.close()

# Check with an alias specified for the output name
f = jj.open_output("my_output_JuliettCopy")
print(f.keys())
f.close()


def core_test_map(comm):
size = comm.size if comm else 1
Expand Down

0 comments on commit 90e11f5

Please sign in to comment.