
Commit

Merge pull request #308 from lsst/tickets/DM-38041
DM-38041: delegate to new pre-exec-init implementations in pipe_base
TallJimbo authored Sep 12, 2024
2 parents 32219db + 3226412 commit 12378ff
Showing 3 changed files with 18 additions and 300 deletions.
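At a glance, the change swaps three hand-rolled persistence loops in preExecInit.py for three QuantumGraph methods that now live in pipe_base. A hypothetical wrapper for orientation (the function name and signature are illustrative; the calls and keyword arguments are exactly those that appear in the new code below):

from lsst.daf.butler import LimitedButler
from lsst.pipe.base import QuantumGraph

def write_init_products(graph: QuantumGraph, butler: LimitedButler, extend_run: bool) -> None:
    # Replaces the per-task loop in PreExecInitBase.saveInitOutputs.
    graph.write_init_outputs(butler, skip_existing=extend_run)
    # Replaces the config-comparison loop in PreExecInitBase.saveConfigs.
    graph.write_configs(butler, compare_existing=extend_run)
    # Replaces the merge-and-replace logic in PreExecInitBase.savePackageVersions.
    graph.write_packages(butler, compare_existing=extend_run)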
40 changes: 2 additions & 38 deletions python/lsst/ctrl/mpexec/cmdLineFwk.py
@@ -48,7 +48,6 @@
Butler,
CollectionType,
Config,
DatasetId,
DatasetType,
DimensionConfig,
DimensionUniverse,
@@ -58,7 +57,6 @@
Registry,
)
from lsst.daf.butler.datastore.cache_manager import DatastoreCacheManager
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.direct_butler import DirectButler
from lsst.daf.butler.registry import MissingCollectionError, RegistryDefaults
from lsst.daf.butler.registry.wildcards import CollectionWildcard
@@ -956,42 +954,8 @@ def preExecInitQBB(self, task_factory: TaskFactory, args: SimpleNamespace) -> None:
# but we need datastore records for initInputs, and those are only
# available from Quanta, so load the whole thing.
qgraph = QuantumGraph.loadUri(args.qgraph, graphID=args.qgraph_id)
universe = qgraph.universe

# Collect all init input/output dataset IDs.
predicted_inputs: set[DatasetId] = set()
predicted_outputs: set[DatasetId] = set()
for taskDef in qgraph.iterTaskGraph():
if (refs := qgraph.initInputRefs(taskDef)) is not None:
predicted_inputs.update(ref.id for ref in refs)
if (refs := qgraph.initOutputRefs(taskDef)) is not None:
predicted_outputs.update(ref.id for ref in refs)
predicted_outputs.update(ref.id for ref in qgraph.globalInitOutputRefs())
# remove intermediates from inputs
predicted_inputs -= predicted_outputs

# Very inefficient way to extract datastore records from quantum graph,
# we have to scan all quanta and look at their datastore records.
datastore_records: dict[str, DatastoreRecordData] = {}
for quantum_node in qgraph:
for store_name, records in quantum_node.quantum.datastore_records.items():
subset = records.subset(predicted_inputs)
if subset is not None:
datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)

dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

# Make butler from everything.
butler = QuantumBackedButler.from_predicted(
config=args.butler_config,
predicted_inputs=predicted_inputs,
predicted_outputs=predicted_outputs,
dimensions=universe,
datastore_records=datastore_records,
search_paths=args.config_search_path,
dataset_types=dataset_types,
)

# Make QBB.
butler = qgraph.make_init_qbb(args.butler_config, config_search_paths=args.config_search_path)
# Save all InitOutputs, configs, etc.
preExecInit = PreExecInitLimited(butler, task_factory)
preExecInit.initialize(qgraph)
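For reference, the new preExecInitQBB body reduces to three steps. A minimal sketch assuming only the calls visible in this diff (the wrapper name and the None default for config_search_paths are illustrative; the PreExecInitLimited import path follows the file location below):

from lsst.ctrl.mpexec.preExecInit import PreExecInitLimited
from lsst.pipe.base import QuantumGraph, TaskFactory

def pre_exec_init_qbb(qgraph_uri: str, butler_config: str, task_factory: TaskFactory) -> None:
    # Load the whole graph: init-input datastore records are only
    # available from the quanta themselves.
    qgraph = QuantumGraph.loadUri(qgraph_uri)
    # pipe_base now builds the quantum-backed butler in a single call.
    butler = qgraph.make_init_qbb(butler_config, config_search_paths=None)
    # Write all InitOutputs, configs, and package versions.
    PreExecInitLimited(butler, task_factory).initialize(qgraph)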
276 changes: 15 additions & 261 deletions python/lsst/ctrl/mpexec/preExecInit.py
@@ -34,57 +34,19 @@
# -------------------------------
import abc
import logging
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetRef, DatasetType
from lsst.daf.butler.registry import ConflictingDefinitionError, MissingDatasetTypeError
from lsst.pipe.base.automatic_connection_constants import (
PACKAGES_INIT_OUTPUT_NAME,
PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
)
from lsst.utils.packages import Packages

if TYPE_CHECKING:
from lsst.daf.butler import Butler, LimitedButler
from lsst.pipe.base import QuantumGraph, TaskDef, TaskFactory
from lsst.pipe.base import QuantumGraph, TaskFactory

_LOG = logging.getLogger(__name__)


class MissingReferenceError(Exception):
"""Exception raised when resolved reference is missing from graph."""

pass


def _compare_packages(old_packages: Packages, new_packages: Packages) -> None:
"""Compare two versions of Packages.
Parameters
----------
old_packages : `Packages`
Previously recorded package versions.
new_packages : `Packages`
New set of package versions.
Raises
------
TypeError
Raised if parameters are inconsistent.
"""
diff = new_packages.difference(old_packages)
if diff:
versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
raise TypeError(f"Package versions mismatch: ({versions_str})")
else:
_LOG.debug("new packages are consistent with old")


class PreExecInitBase(abc.ABC):
"""Common part of the implementation of PreExecInit classes that does not
depend on Butler type.
@@ -94,14 +56,13 @@ class PreExecInitBase(abc.ABC):
butler : `~lsst.daf.butler.LimitedButler`
Butler to use.
taskFactory : `lsst.pipe.base.TaskFactory`
Task factory.
Ignored and accepted for backwards compatibility.
extendRun : `bool`
Whether extend run parameter is in use.
"""

def __init__(self, butler: LimitedButler, taskFactory: TaskFactory, extendRun: bool):
self.butler = butler
self.taskFactory = taskFactory
self.extendRun = extendRun

def initialize(
@@ -183,36 +144,7 @@ def saveInitOutputs(self, graph: QuantumGraph) -> None:
new data.
"""
_LOG.debug("Will save InitOutputs for all tasks")
for taskDef in self._task_iter(graph):
init_input_refs = graph.initInputRefs(taskDef) or []
task = self.taskFactory.makeTask(
graph.pipeline_graph.tasks[taskDef.label], self.butler, init_input_refs
)
for name in taskDef.connections.initOutputs:
attribute = getattr(taskDef.connections, name)
init_output_refs = graph.initOutputRefs(taskDef) or []
init_output_ref, obj_from_store = self._find_dataset(init_output_refs, attribute.name)
if init_output_ref is None:
raise ValueError(f"Cannot find dataset reference for init output {name} in a graph")
init_output_var = getattr(task, name)

if obj_from_store is not None:
_LOG.debug(
"Retrieving InitOutputs for task=%s key=%s dsTypeName=%s", task, name, attribute.name
)
obj_from_store = self.butler.get(init_output_ref)
# Types are supposed to be identical.
# TODO: Check that object contents is identical too.
if type(obj_from_store) is not type(init_output_var):
raise TypeError(
f"Stored initOutput object type {type(obj_from_store)} "
"is different from task-generated type "
f"{type(init_output_var)} for task {taskDef}"
)
else:
_LOG.debug("Saving InitOutputs for task=%s key=%s", taskDef.label, name)
# This can still raise if there is a concurrent write.
self.butler.put(init_output_var, init_output_ref)
graph.write_init_outputs(self.butler, skip_existing=self.extendRun)

def saveConfigs(self, graph: QuantumGraph) -> None:
"""Write configurations for pipeline tasks to butler or check that
Expand All @@ -225,49 +157,13 @@ def saveConfigs(self, graph: QuantumGraph) -> None:
Raises
------
TypeError
Raised if existing object in butler is different from new data.
Exception
Raised if ``extendRun`` is `False` and datasets already exists.
Content of a butler collection should not be changed if exception
is raised.
ConflictingDefinitionError
Raised if existing object in butler is different from new data, or
if ``extendRun`` is `False` and datasets already exists.
Content of a butler collection should not be changed if this
exception is raised.
"""

def logConfigMismatch(msg: str) -> None:
"""Log messages about configuration mismatch.
Parameters
----------
msg : `str`
Log message to use.
"""
_LOG.fatal("Comparing configuration: %s", msg)

_LOG.debug("Will save Configs for all tasks")
# start transaction to rollback any changes on exceptions
with self.transaction():
for taskDef in self._task_iter(graph):
# Config dataset ref is stored in task init outputs, but it
# may be also be missing.
task_output_refs = graph.initOutputRefs(taskDef)
if task_output_refs is None:
continue

config_ref, old_config = self._find_dataset(task_output_refs, taskDef.configDatasetName)
if config_ref is None:
continue

if old_config is not None:
if not taskDef.config.compare(old_config, shortcut=False, output=logConfigMismatch):
raise TypeError(
f"Config does not match existing task config {taskDef.configDatasetName!r} in "
"butler; tasks configurations must be consistent within the same run collection"
)
else:
_LOG.debug(
"Saving Config for task=%s dataset type=%s", taskDef.label, taskDef.configDatasetName
)
self.butler.put(taskDef.config, config_ref)
graph.write_configs(self.butler, compare_existing=self.extendRun)

def savePackageVersions(self, graph: QuantumGraph) -> None:
"""Write versions of software packages to butler.
Expand All @@ -282,96 +178,7 @@ def savePackageVersions(self, graph: QuantumGraph) -> None:
TypeError
Raised if existing object in butler is incompatible with new data.
"""
packages = Packages.fromSystem()
_LOG.debug("want to save packages: %s", packages)

# start transaction to rollback any changes on exceptions
with self.transaction():
# Packages dataset ref is stored in graph's global init outputs,
# but it may be also be missing.

packages_ref, old_packages = self._find_dataset(
graph.globalInitOutputRefs(), PACKAGES_INIT_OUTPUT_NAME
)
if packages_ref is None:
return

if old_packages is not None:
# Note that because we can only detect python modules that have
# been imported, the stored list of products may be more or
# less complete than what we have now. What's important is
# that the products that are in common have the same version.
_compare_packages(old_packages, packages)
# Update the old set of packages in case we have more packages
# that haven't been persisted.
extra = packages.extra(old_packages)
if extra:
_LOG.debug("extra packages: %s", extra)
old_packages.update(packages)
# have to remove existing dataset first, butler has no
# replace option.
self.butler.pruneDatasets([packages_ref], unstore=True, purge=True)
self.butler.put(old_packages, packages_ref)
else:
self.butler.put(packages, packages_ref)

def _find_dataset(
self, refs: Iterable[DatasetRef], dataset_type: str
) -> tuple[DatasetRef | None, Any | None]:
"""Find a ref with a given dataset type name in a list of references
and try to retrieve its data from butler.
Parameters
----------
refs : `~collections.abc.Iterable` [ `~lsst.daf.butler.DatasetRef` ]
References to check for matching dataset type.
dataset_type : `str`
Name of a dataset type to look for.
Returns
-------
ref : `~lsst.daf.butler.DatasetRef` or `None`
Dataset reference or `None` if there is no matching dataset type.
data : `Any`
An existing object extracted from butler, `None` if ``ref`` is
`None` or if there is no existing object for that reference.
"""
ref: DatasetRef | None = None
for ref in refs:
if ref.datasetType.name == dataset_type:
break
else:
return None, None

try:
data = self.butler.get(ref)
if data is not None and not self.extendRun:
# It must not exist unless we are extending run.
raise ConflictingDefinitionError(f"Dataset {ref} already exists in butler")
except (LookupError, FileNotFoundError):
data = None
return ref, data

def _task_iter(self, graph: QuantumGraph) -> Iterator[TaskDef]:
"""Iterate over TaskDefs in a graph, return only tasks that have one or
more associated quanta.
"""
for taskDef in graph.iterTaskGraph():
if graph.getNumberOfQuantaForTask(taskDef) > 0:
yield taskDef

@contextmanager
def transaction(self) -> Iterator[None]:
"""Context manager for transaction.
Default implementation has no transaction support.
Yields
------
`None`
No transaction support.
"""
yield
graph.write_packages(self.butler, compare_existing=self.extendRun)


class PreExecInit(PreExecInitBase):
@@ -402,65 +209,12 @@ def __init__(self, butler: Butler, taskFactory: TaskFactory, extendRun: bool = False):
"with a default output RUN collection."
)

@contextmanager
def transaction(self) -> Iterator[None]:
# docstring inherited
with self.full_butler.transaction():
yield

def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
# docstring inherited
missing_dataset_types: set[str] = set()
dataset_types = [node.dataset_type for node in graph.pipeline_graph.dataset_types.values()]
dataset_types.append(
DatasetType(
PACKAGES_INIT_OUTPUT_NAME, self.butler.dimensions.empty, PACKAGES_INIT_OUTPUT_STORAGE_CLASS
)
)
for dataset_type in dataset_types:
# Resolving the PipelineGraph when building the QuantumGraph should
# have already guaranteed that this is the registry dataset type
# and that all references to it use compatible storage classes,
# so we don't need another check for compatibility here; if the
# dataset type doesn't match the registry that's already a problem.
if registerDatasetTypes:
_LOG.debug("Registering DatasetType %s with registry", dataset_type.name)
try:
self.full_butler.registry.registerDatasetType(dataset_type)
except ConflictingDefinitionError:
expected = self.full_butler.registry.getDatasetType(dataset_type.name)
raise ConflictingDefinitionError(
f"DatasetType definition in registry has changed since the QuantumGraph was built: "
f"{dataset_type} (graph) != {expected} (registry)."
)
else:
_LOG.debug("Checking DatasetType %s against registry", dataset_type.name)
try:
expected = self.full_butler.registry.getDatasetType(dataset_type.name)
except MissingDatasetTypeError:
# Likely means that --register-dataset-types is forgotten,
# but we could also get here if there is a prerequisite
# input that is optional and none were found in this repo;
# that is not an error. And we don't bother to check if
# they are optional here, since the fact that we were able
# to make the QG says that they were, since there couldn't
# have been any datasets if the dataset types weren't
# registered.
if not graph.pipeline_graph.dataset_types[dataset_type.name].is_prerequisite:
missing_dataset_types.add(dataset_type.name)
continue
if expected != dataset_type:
raise ConflictingDefinitionError(
f"DatasetType definition in registry has changed since the QuantumGraph was built: "
f"{dataset_type} (graph) != {expected} (registry)."
)
if missing_dataset_types:
plural = "s" if len(missing_dataset_types) != 1 else ""
raise MissingDatasetTypeError(
f"Missing dataset type definition{plural}: {', '.join(missing_dataset_types)}. "
"Dataset types have to be registered with either `butler register-dataset-type` or "
"passing `--register-dataset-types` option to `pipetask run`."
)
if registerDatasetTypes:
graph.pipeline_graph.register_dataset_types(self.full_butler)
else:
graph.pipeline_graph.check_dataset_type_registrations(self.full_butler)


class PreExecInitLimited(PreExecInitBase):
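With the deleted checking loop gone, the register-versus-check decision is all that remains of initializeDatasetTypes. A minimal sketch of that branch, assuming butler is the full Butler the class stores as self.full_butler:

from lsst.daf.butler import Butler
from lsst.pipe.base import QuantumGraph

def initialize_dataset_types(graph: QuantumGraph, butler: Butler, register: bool = False) -> None:
    if register:
        # Register every dataset type defined by the pipeline graph,
        # replacing the per-type registerDatasetType loop deleted above.
        graph.pipeline_graph.register_dataset_types(butler)
    else:
        # Verify existing registrations only, raising on definitions that
        # changed since the QuantumGraph was built, as the old loop did.
        graph.pipeline_graph.check_dataset_type_registrations(butler)

Moving both branches onto PipelineGraph keeps the registry checks next to the graph's own dataset-type definitions, so ctrl_mpexec no longer duplicates that logic.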
