Merge pull request #410 from lsst/tickets/DM-43572
DM-43572: Input connection with deferred binding
tgoldina authored Mar 29, 2024
2 parents 697e81c + 17eb63f commit 3034863
Showing 4 changed files with 55 additions and 53 deletions.
2 changes: 2 additions & 0 deletions doc/changes/DM-43572.feature.rst
@@ -0,0 +1,2 @@
Added ``deferBinding`` attribute to ``Input`` connection, which allows an input
connection to have the same dataset type as an output.
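
For illustration, a minimal, untested sketch of how a task's connections class might use the new attribute so that an input and an output share a dataset type; the class, connection, dataset type, and dimension names below are invented for this example.

from lsst.pipe.base import PipelineTaskConnections
import lsst.pipe.base.connectionTypes as cT


class ExampleConnections(
    PipelineTaskConnections, dimensions=("instrument", "visit", "detector")
):
    """Hypothetical connections in which an input and an output share a
    dataset type.
    """

    # With deferBinding=True this input is left out of the automatically
    # built pipeline graph; a custom QuantumGraphBuilder is expected to
    # bind it and add the corresponding edge.
    inputCatalog = cT.Input(
        doc="Catalog written by an earlier execution of this task.",
        name="example_catalog",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
        deferBinding=True,
    )
    outputCatalog = cT.Output(
        doc="Updated catalog, using the same dataset type as the input above.",
        name="example_catalog",
        storageClass="SourceCatalog",
        dimensions=("instrument", "visit", "detector"),
    )
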
9 changes: 9 additions & 0 deletions python/lsst/pipe/base/connectionTypes.py
@@ -298,6 +298,13 @@ class Input(BaseInput):
spatial overlaps. This option has no effect when the connection is not
an overall input of the pipeline (or subset thereof) for which a graph
is being created, and it never affects the ordering of quanta.
deferBinding : `bool`, optional
If `True`, the dataset will not be automatically included in
the pipeline graph; ``deferGraphConstraint`` is implied.
A custom QuantumGraphBuilder is required to bind it and add a
corresponding edge to the pipeline graph.
This option allows the same dataset type to be used as both an
input and an output of a quantum.
Raises
------
@@ -310,6 +317,8 @@ class Input(BaseInput):

deferGraphConstraint: bool = False

deferBinding: bool = False

_connection_type_set: ClassVar[str] = "inputs"

def __post_init__(self) -> None:
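The connection classes are dataclasses, so the new flag can also be inspected directly on a connection instance. A small sketch, with a made-up dataset type name; in real code, connections are declared as class attributes of a PipelineTaskConnections subclass, as in the sketch above.

import lsst.pipe.base.connectionTypes as cT

# Constructed by hand only to look at the new field; the dataset type
# name here is hypothetical.
conn = cT.Input(
    doc="Input that a custom QuantumGraphBuilder will bind.",
    name="example_dataset",
    storageClass="StructuredDataDict",
    dimensions=("instrument",),
    deferBinding=True,
)
print(conn.deferBinding)  # True (the default is False)
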
96 changes: 43 additions & 53 deletions python/lsst/pipe/base/connections.py
@@ -759,60 +759,50 @@ def buildDatasetRefs(
"""
inputDatasetRefs = InputQuantizedConnection()
outputDatasetRefs = OutputQuantizedConnection()
# operate on a reference object and an iterable of names of class
# connection attributes
for refs, names in zip(
(inputDatasetRefs, outputDatasetRefs),
(itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs),
strict=True,
):
# get a name of a class connection attribute
for attributeName in names:
# get the attribute identified by name
attribute = getattr(self, attributeName)
# Branch if the attribute dataset type is an input
if attribute.name in quantum.inputs:
# if the dataset is marked to load deferred, wrap it in a
# DeferredDatasetRef
quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
if attribute.deferLoad:
quantumInputRefs = [
DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
]
else:
quantumInputRefs = list(quantum.inputs[attribute.name])
# Unpack arguments that are not marked multiples (list of
# length one)
if not attribute.multiple:
if len(quantumInputRefs) > 1:
raise ScalarError(
"Received multiple datasets "
f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
f"for scalar connection {attributeName} "
f"({quantumInputRefs[0].datasetType.name}) "
f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
)
if len(quantumInputRefs) == 0:
continue
setattr(refs, attributeName, quantumInputRefs[0])
else:
# Add to the QuantizedConnection identifier
setattr(refs, attributeName, quantumInputRefs)
# Branch if the attribute dataset type is an output
elif attribute.name in quantum.outputs:
value = quantum.outputs[attribute.name]
# Unpack arguments that are not marked multiples (list of
# length one)
if not attribute.multiple:
setattr(refs, attributeName, value[0])
else:
setattr(refs, attributeName, value)
# Specified attribute is not in inputs or outputs; don't know how
# to handle, throw
else:
raise ValueError(
f"Attribute with name {attributeName} has no counterpart in input quantum"
)

# populate inputDatasetRefs from quantum inputs
for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs):
# get the attribute identified by name
attribute = getattr(self, attributeName)
# if the dataset is marked to load deferred, wrap it in a
# DeferredDatasetRef
quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef]
if attribute.deferLoad:
quantumInputRefs = [
DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name]
]
else:
quantumInputRefs = list(quantum.inputs[attribute.name])
# Unpack arguments that are not marked multiples (list of
# length one)
if not attribute.multiple:
if len(quantumInputRefs) > 1:
raise ScalarError(
"Received multiple datasets "
f"{', '.join(str(r.dataId) for r in quantumInputRefs)} "
f"for scalar connection {attributeName} "
f"({quantumInputRefs[0].datasetType.name}) "
f"of quantum for {quantum.taskName} with data ID {quantum.dataId}."
)
if len(quantumInputRefs) == 0:
continue
setattr(inputDatasetRefs, attributeName, quantumInputRefs[0])
else:
# Add to the QuantizedConnection identifier
setattr(inputDatasetRefs, attributeName, quantumInputRefs)

# populate outputDatasetRefs from quantum outputs
for attributeName in self.outputs:
# get the attribute identified by name
attribute = getattr(self, attributeName)
value = quantum.outputs[attribute.name]
# Unpack arguments that are not marked multiples (list of
# length one)
if not attribute.multiple:
setattr(outputDatasetRefs, attributeName, value[0])
else:
setattr(outputDatasetRefs, attributeName, value)

return inputDatasetRefs, outputDatasetRefs

def adjustQuantum(
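The scalar-versus-multiple unpacking rule in buildDatasetRefs above is easy to lose in the diff, so here is a simplified, standalone illustration of the same rule. It uses plain strings in place of DatasetRef objects and a plain ValueError where the real code raises ScalarError; the function name is invented for this sketch.

def unpack_connection_values(values: list, multiple: bool, connection_name: str):
    """Mimic how buildDatasetRefs unpacks the refs for one connection."""
    if multiple:
        # A connection marked multiple keeps the whole list.
        return values
    if len(values) > 1:
        # A scalar connection must not receive more than one dataset.
        raise ValueError(
            f"Received multiple datasets for scalar connection {connection_name}."
        )
    if not values:
        # buildDatasetRefs simply skips setting the attribute in this case.
        return None
    return values[0]


print(unpack_connection_values(["ref1"], multiple=False, connection_name="calexp"))
# -> 'ref1'
print(unpack_connection_values(["ref1", "ref2"], multiple=True, connection_name="srcs"))
# -> ['ref1', 'ref2']
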
1 change: 1 addition & 0 deletions python/lsst/pipe/base/pipeline_graph/_tasks.py
@@ -527,6 +527,7 @@ def _from_imported_data(
inputs = {
name: ReadEdge._from_connection_map(key, name, data.connection_map)
for name in data.connections.inputs
if not getattr(data.connections, name).deferBinding
}
init_outputs = {
name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
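The one-line change above excludes deferred-binding connections from the automatically generated read edges. A toy illustration of the same getattr-based filtering pattern, with invented FakeConnection/FakeConnections stand-ins for the real connection objects:

from dataclasses import dataclass


@dataclass
class FakeConnection:
    name: str
    deferBinding: bool = False


class FakeConnections:
    calexp = FakeConnection("calexp")
    deferred = FakeConnection("example_catalog", deferBinding=True)


connections = FakeConnections()
# Same pattern as the dict comprehension above: connections whose
# deferBinding flag is set are skipped when building input edges.
inputs = {
    name: getattr(connections, name)
    for name in ("calexp", "deferred")
    if not getattr(connections, name).deferBinding
}
print(sorted(inputs))  # ['calexp']
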
