def get_refs(
    self,
    *,
    include_init_inputs: bool = False,
    include_inputs: bool = False,
    include_intermediates: bool | None = None,
    include_init_outputs: bool = False,
    include_outputs: bool = False,
    conform_outputs: bool = True,
) -> tuple[set[DatasetRef], dict[str, DatastoreRecordData]]:
    """Get the requested dataset refs from the graph.

    Parameters
    ----------
    include_init_inputs : `bool`, optional
        Include init inputs.
    include_inputs : `bool`, optional
        Include inputs.
    include_intermediates : `bool` or `None`, optional
        If `None`, no special handling for intermediates is performed.
        If `True` intermediates are calculated even if other flags
        do not request datasets. If `False` intermediates will be removed
        from any results.
    include_init_outputs : `bool`, optional
        Include init outputs.
    include_outputs : `bool`, optional
        Include outputs.
    conform_outputs : `bool`, optional
        Whether any outputs found should have their dataset types conformed
        with the registry dataset types.

    Returns
    -------
    refs : `set` [ `lsst.daf.butler.DatasetRef` ]
        The requested dataset refs found in the graph.
    datastore_records : `dict` [ `str`, \
            `lsst.daf.butler.datastore.record_data.DatastoreRecordData` ]
        Any datastore records found.
    """
    datastore_records: dict[str, DatastoreRecordData] = {}
    init_input_refs: set[DatasetRef] = set()
    init_output_refs: set[DatasetRef] = set()

    if include_intermediates is True:
        # Intermediates are the overlap of inputs and outputs, so every
        # category must be collected even if not explicitly requested;
        # unrequested categories are dropped again before returning.
        request_include_init_inputs = True
        request_include_inputs = True
        request_include_init_outputs = True
        request_include_outputs = True
    else:
        request_include_init_inputs = include_init_inputs
        request_include_inputs = include_inputs
        request_include_init_outputs = include_init_outputs
        request_include_outputs = include_outputs

    if request_include_init_inputs or request_include_init_outputs:
        for task_def in self.iterTaskGraph():
            if request_include_init_inputs:
                if in_refs := self.initInputRefs(task_def):
                    init_input_refs.update(in_refs)
            if request_include_init_outputs:
                if out_refs := self.initOutputRefs(task_def):
                    init_output_refs.update(out_refs)
    if request_include_init_outputs:
        # Per-graph init outputs (e.g. packages) are not attached to any
        # single task; the caller code this method replaces collected them
        # via globalInitOutputRefs(), so they must be included here too.
        init_output_refs.update(self.globalInitOutputRefs())

    input_refs: set[DatasetRef] = set()
    output_refs: set[DatasetRef] = set()

    for qnode in self:
        if request_include_inputs:
            for other_refs in qnode.quantum.inputs.values():
                input_refs.update(other_refs)
            # Inputs can come with datastore records.
            for store_name, records in qnode.quantum.datastore_records.items():
                datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
        if request_include_outputs:
            for other_refs in qnode.quantum.outputs.values():
                output_refs.update(other_refs)

    if conform_outputs:
        # Get data repository definitions from the QuantumGraph; these can
        # have different storage classes than those in the quanta.
        dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}

        def _update_ref(ref: DatasetRef) -> DatasetRef:
            # Conform a single ref to the registry storage class if needed.
            internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
            if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
                ref = ref.overrideStorageClass(internal_dataset_type.storageClass_name)
            return ref

        # Convert output refs to the data repository storage classes, too.
        output_refs = {_update_ref(ref) for ref in output_refs}
        init_output_refs = {_update_ref(ref) for ref in init_output_refs}

    # Intermediates are the intersection of all inputs with all outputs.
    intermediates: set[DatasetRef] = set()
    if include_intermediates is not None:
        # Bug fix: the two output categories must be combined with union,
        # not intersection -- a ref cannot be both an init output and a
        # run output, so "(output_refs & init_output_refs)" was always
        # empty and made include_intermediates a no-op.
        intermediates = (input_refs | init_input_refs) & (output_refs | init_output_refs)
    if include_intermediates is False:
        # Remove intermediates from results.
        init_input_refs -= intermediates
        input_refs -= intermediates
        init_output_refs -= intermediates
        output_refs -= intermediates
        intermediates = set()

    # Drop categories that were only collected to compute intermediates.
    if not include_init_inputs:
        init_input_refs = set()
    if not include_inputs:
        input_refs = set()
    if not include_init_outputs:
        init_output_refs = set()
    if not include_outputs:
        output_refs = set()
    inter_msg = f"; Intermediates: {len(intermediates)}" if intermediates else ""

    _LOG.info(
        # Bug fix: InitOutputs used %s for an integer count; use %d like
        # the other counters.
        "Found the following datasets. InitInputs: %d; Inputs: %d; InitOutputs: %d; Outputs: %d%s",
        len(init_input_refs),
        len(input_refs),
        len(init_output_refs),
        len(output_refs),
        inter_msg,
    )
    refs = input_refs | init_input_refs | init_output_refs | output_refs | intermediates
    return refs, datastore_records
- for task_def in qgraph.iterTaskGraph(): - if in_refs := qgraph.initInputRefs(task_def): - refs.update(in_refs) - for qnode in qgraph: - for otherRefs in qnode.quantum.inputs.values(): - refs.update(otherRefs) - for store_name, records in qnode.quantum.datastore_records.items(): - datastore_records.setdefault(store_name, DatastoreRecordData()).update(records) - n_inputs = len(refs) - if n_inputs: - _LOG.info("Found %d input dataset%s.", n_inputs, "" if n_inputs == 1 else "s") - - if include_outputs: - # Collect output refs that could be created by this graph. - original_output_refs: set[DatasetRef] = set(qgraph.globalInitOutputRefs()) - for task_def in qgraph.iterTaskGraph(): - if out_refs := qgraph.initOutputRefs(task_def): - original_output_refs.update(out_refs) - for qnode in qgraph: - for otherRefs in qnode.quantum.outputs.values(): - original_output_refs.update(otherRefs) - - # Convert output_refs to the data repository storage classes, too. - for ref in original_output_refs: - internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType) - if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name: - refs.add(ref.overrideStorageClass(internal_dataset_type.storageClass_name)) - else: - refs.add(ref) - - n_outputs = len(refs) - n_inputs - if n_outputs: - _LOG.info("Found %d output dataset%s.", n_outputs, "" if n_outputs == 1 else "s") - # Make QBB, its config is the same as output Butler. qbb = QuantumBackedButler.from_predicted( config=repo,