From 17eb63ff4712ea78f93a9cdcc63cd8dfcee65244 Mon Sep 17 00:00:00 2001 From: tatianag Date: Mon, 25 Mar 2024 09:01:29 -0700 Subject: [PATCH] Add Input connection with deferred binding. This would allow an input to have the same dataset type as an output. --- doc/changes/DM-43572.feature.rst | 2 + python/lsst/pipe/base/connectionTypes.py | 9 ++ python/lsst/pipe/base/connections.py | 96 +++++++++---------- .../lsst/pipe/base/pipeline_graph/_tasks.py | 1 + 4 files changed, 55 insertions(+), 53 deletions(-) create mode 100644 doc/changes/DM-43572.feature.rst diff --git a/doc/changes/DM-43572.feature.rst b/doc/changes/DM-43572.feature.rst new file mode 100644 index 000000000..fb078d465 --- /dev/null +++ b/doc/changes/DM-43572.feature.rst @@ -0,0 +1,2 @@ +Added ``deferBinding`` attribute to ``Input`` connection, which allows us +to have an input connection with the same dataset type as an output. \ No newline at end of file diff --git a/python/lsst/pipe/base/connectionTypes.py b/python/lsst/pipe/base/connectionTypes.py index ad378222a..9c186cad6 100644 --- a/python/lsst/pipe/base/connectionTypes.py +++ b/python/lsst/pipe/base/connectionTypes.py @@ -298,6 +298,13 @@ class Input(BaseInput): spatial overlaps. This option has no effect when the connection is not an overall input of the pipeline (or subset thereof) for which a graph is being created, and it never affects the ordering of quanta. + deferBinding : `bool`, optional + If `True`, the dataset will not be automatically included in + the pipeline graph; ``deferGraphConstraint`` is implied. + A custom QuantumGraphBuilder is required to bind it and add a + corresponding edge to the pipeline graph. + This option allows having the same dataset type as both + input and output of a quantum. 
Raises ------ @@ -310,6 +317,8 @@ class Input(BaseInput): deferGraphConstraint: bool = False + deferBinding: bool = False + _connection_type_set: ClassVar[str] = "inputs" def __post_init__(self) -> None: diff --git a/python/lsst/pipe/base/connections.py b/python/lsst/pipe/base/connections.py index 83395d373..34ebd2250 100644 --- a/python/lsst/pipe/base/connections.py +++ b/python/lsst/pipe/base/connections.py @@ -759,60 +759,50 @@ def buildDatasetRefs( """ inputDatasetRefs = InputQuantizedConnection() outputDatasetRefs = OutputQuantizedConnection() - # operate on a reference object and an iterable of names of class - # connection attributes - for refs, names in zip( - (inputDatasetRefs, outputDatasetRefs), - (itertools.chain(self.inputs, self.prerequisiteInputs), self.outputs), - strict=True, - ): - # get a name of a class connection attribute - for attributeName in names: - # get the attribute identified by name - attribute = getattr(self, attributeName) - # Branch if the attribute dataset type is an input - if attribute.name in quantum.inputs: - # if the dataset is marked to load deferred, wrap it in a - # DeferredDatasetRef - quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef] - if attribute.deferLoad: - quantumInputRefs = [ - DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name] - ] - else: - quantumInputRefs = list(quantum.inputs[attribute.name]) - # Unpack arguments that are not marked multiples (list of - # length one) - if not attribute.multiple: - if len(quantumInputRefs) > 1: - raise ScalarError( - "Received multiple datasets " - f"{', '.join(str(r.dataId) for r in quantumInputRefs)} " - f"for scalar connection {attributeName} " - f"({quantumInputRefs[0].datasetType.name}) " - f"of quantum for {quantum.taskName} with data ID {quantum.dataId}." 
- ) - if len(quantumInputRefs) == 0: - continue - setattr(refs, attributeName, quantumInputRefs[0]) - else: - # Add to the QuantizedConnection identifier - setattr(refs, attributeName, quantumInputRefs) - # Branch if the attribute dataset type is an output - elif attribute.name in quantum.outputs: - value = quantum.outputs[attribute.name] - # Unpack arguments that are not marked multiples (list of - # length one) - if not attribute.multiple: - setattr(refs, attributeName, value[0]) - else: - setattr(refs, attributeName, value) - # Specified attribute is not in inputs or outputs dont know how - # to handle, throw - else: - raise ValueError( - f"Attribute with name {attributeName} has no counterpart in input quantum" + + # populate inputDatasetRefs from quantum inputs + for attributeName in itertools.chain(self.inputs, self.prerequisiteInputs): + # get the attribute identified by name + attribute = getattr(self, attributeName) + # if the dataset is marked to load deferred, wrap it in a + # DeferredDatasetRef + quantumInputRefs: list[DatasetRef] | list[DeferredDatasetRef] + if attribute.deferLoad: + quantumInputRefs = [ + DeferredDatasetRef(datasetRef=ref) for ref in quantum.inputs[attribute.name] + ] + else: + quantumInputRefs = list(quantum.inputs[attribute.name]) + # Unpack arguments that are not marked multiples (list of + # length one) + if not attribute.multiple: + if len(quantumInputRefs) > 1: + raise ScalarError( + "Received multiple datasets " + f"{', '.join(str(r.dataId) for r in quantumInputRefs)} " + f"for scalar connection {attributeName} " + f"({quantumInputRefs[0].datasetType.name}) " + f"of quantum for {quantum.taskName} with data ID {quantum.dataId}." 
) + if len(quantumInputRefs) == 0: + continue + setattr(inputDatasetRefs, attributeName, quantumInputRefs[0]) + else: + # Add to the QuantizedConnection identifier + setattr(inputDatasetRefs, attributeName, quantumInputRefs) + + # populate outputDatasetRefs from quantum outputs + for attributeName in self.outputs: + # get the attribute identified by name + attribute = getattr(self, attributeName) + value = quantum.outputs[attribute.name] + # Unpack arguments that are not marked multiples (list of + # length one) + if not attribute.multiple: + setattr(outputDatasetRefs, attributeName, value[0]) + else: + setattr(outputDatasetRefs, attributeName, value) + return inputDatasetRefs, outputDatasetRefs def adjustQuantum( diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py index 945706fd3..e8eb1bb0c 100644 --- a/python/lsst/pipe/base/pipeline_graph/_tasks.py +++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py @@ -527,6 +527,7 @@ def _from_imported_data( inputs = { name: ReadEdge._from_connection_map(key, name, data.connection_map) for name in data.connections.inputs + if not getattr(data.connections, name).deferBinding } init_outputs = { name: WriteEdge._from_connection_map(init_key, name, data.connection_map)