DM-38041: delegate to new pre-exec-init implementations in pipe_base #308

Merged · 4 commits · Sep 12, 2024
40 changes: 2 additions & 38 deletions python/lsst/ctrl/mpexec/cmdLineFwk.py
@@ -48,7 +48,6 @@
Butler,
CollectionType,
Config,
DatasetId,
DatasetType,
DimensionConfig,
DimensionUniverse,
@@ -58,7 +57,6 @@
Registry,
)
from lsst.daf.butler.datastore.cache_manager import DatastoreCacheManager
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.direct_butler import DirectButler
from lsst.daf.butler.registry import MissingCollectionError, RegistryDefaults
from lsst.daf.butler.registry.wildcards import CollectionWildcard
@@ -956,42 +954,8 @@ def preExecInitQBB(self, task_factory: TaskFactory, args: SimpleNamespace) -> None:
# but we need datastore records for initInputs, and those are only
# available from Quanta, so load the whole thing.
qgraph = QuantumGraph.loadUri(args.qgraph, graphID=args.qgraph_id)
universe = qgraph.universe

# Collect all init input/output dataset IDs.
predicted_inputs: set[DatasetId] = set()
predicted_outputs: set[DatasetId] = set()
for taskDef in qgraph.iterTaskGraph():
if (refs := qgraph.initInputRefs(taskDef)) is not None:
predicted_inputs.update(ref.id for ref in refs)
if (refs := qgraph.initOutputRefs(taskDef)) is not None:
predicted_outputs.update(ref.id for ref in refs)
predicted_outputs.update(ref.id for ref in qgraph.globalInitOutputRefs())
# remove intermediates from inputs
predicted_inputs -= predicted_outputs

# Very inefficient way to extract datastore records from quantum graph,
# we have to scan all quanta and look at their datastore records.
datastore_records: dict[str, DatastoreRecordData] = {}
for quantum_node in qgraph:
for store_name, records in quantum_node.quantum.datastore_records.items():
subset = records.subset(predicted_inputs)
if subset is not None:
datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)

dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

# Make butler from everything.
butler = QuantumBackedButler.from_predicted(
config=args.butler_config,
predicted_inputs=predicted_inputs,
predicted_outputs=predicted_outputs,
dimensions=universe,
datastore_records=datastore_records,
search_paths=args.config_search_path,
dataset_types=dataset_types,
)

# Make QBB.
butler = qgraph.make_init_qbb(args.butler_config, config_search_paths=args.config_search_path)
# Save all InitOutputs, configs, etc.
preExecInit = PreExecInitLimited(butler, task_factory)
preExecInit.initialize(qgraph)
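The forty-odd removed lines above collapse into `QuantumGraph.make_init_qbb`, which pipe_base now owns. A minimal end-to-end sketch of the new QBB pre-exec-init path, using only APIs visible in this diff (file paths and import locations are illustrative, not prescribed by the PR):

```python
from lsst.ctrl.mpexec import TaskFactory
from lsst.ctrl.mpexec.preExecInit import PreExecInitLimited
from lsst.pipe.base import QuantumGraph

# Load the full graph: init-input datastore records only exist on quanta.
qgraph = QuantumGraph.loadUri("pipeline.qgraph")

# pipe_base now collects predicted inputs/outputs, datastore records, and
# registry dataset types internally and returns a quantum-backed butler.
butler = qgraph.make_init_qbb("/repo/butler.yaml", config_search_paths=None)

# Save all InitOutputs, configs, package versions through that butler.
pre_exec_init = PreExecInitLimited(butler, TaskFactory())
pre_exec_init.initialize(qgraph)
```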
276 changes: 15 additions & 261 deletions python/lsst/ctrl/mpexec/preExecInit.py
@@ -34,57 +34,19 @@
# -------------------------------
import abc
import logging
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetRef, DatasetType
from lsst.daf.butler.registry import ConflictingDefinitionError, MissingDatasetTypeError
from lsst.pipe.base.automatic_connection_constants import (
PACKAGES_INIT_OUTPUT_NAME,
PACKAGES_INIT_OUTPUT_STORAGE_CLASS,
)
from lsst.utils.packages import Packages

if TYPE_CHECKING:
from lsst.daf.butler import Butler, LimitedButler
from lsst.pipe.base import QuantumGraph, TaskDef, TaskFactory
from lsst.pipe.base import QuantumGraph, TaskFactory

_LOG = logging.getLogger(__name__)


class MissingReferenceError(Exception):
"""Exception raised when resolved reference is missing from graph."""

pass


def _compare_packages(old_packages: Packages, new_packages: Packages) -> None:
"""Compare two versions of Packages.

Parameters
----------
old_packages : `Packages`
Previously recorded package versions.
new_packages : `Packages`
New set of package versions.

Raises
------
TypeError
Raised if parameters are inconsistent.
"""
diff = new_packages.difference(old_packages)
if diff:
versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
raise TypeError(f"Package versions mismatch: ({versions_str})")
else:
_LOG.debug("new packages are consistent with old")


class PreExecInitBase(abc.ABC):
"""Common part of the implementation of PreExecInit classes that does not
depend on Butler type.
@@ -94,14 +56,13 @@ class PreExecInitBase:
butler : `~lsst.daf.butler.LimitedButler`
Butler to use.
taskFactory : `lsst.pipe.base.TaskFactory`
Task factory.
Accepted for backwards compatibility and otherwise ignored.
extendRun : `bool`
Whether the ``extendRun`` flag is in use.
"""

def __init__(self, butler: LimitedButler, taskFactory: TaskFactory, extendRun: bool):
self.butler = butler
self.taskFactory = taskFactory
self.extendRun = extendRun

def initialize(
@@ -183,36 +144,7 @@ def saveInitOutputs(self, graph: QuantumGraph) -> None:
new data.
"""
_LOG.debug("Will save InitOutputs for all tasks")
for taskDef in self._task_iter(graph):
init_input_refs = graph.initInputRefs(taskDef) or []
task = self.taskFactory.makeTask(
graph.pipeline_graph.tasks[taskDef.label], self.butler, init_input_refs
)
for name in taskDef.connections.initOutputs:
attribute = getattr(taskDef.connections, name)
init_output_refs = graph.initOutputRefs(taskDef) or []
init_output_ref, obj_from_store = self._find_dataset(init_output_refs, attribute.name)
if init_output_ref is None:
raise ValueError(f"Cannot find dataset reference for init output {name} in a graph")
init_output_var = getattr(task, name)

if obj_from_store is not None:
_LOG.debug(
"Retrieving InitOutputs for task=%s key=%s dsTypeName=%s", task, name, attribute.name
)
obj_from_store = self.butler.get(init_output_ref)
# Types are supposed to be identical.
# TODO: Check that object contents are identical too.
if type(obj_from_store) is not type(init_output_var):
raise TypeError(
f"Stored initOutput object type {type(obj_from_store)} "
"is different from task-generated type "
f"{type(init_output_var)} for task {taskDef}"
)
else:
_LOG.debug("Saving InitOutputs for task=%s key=%s", taskDef.label, name)
# This can still raise if there is a concurrent write.
self.butler.put(init_output_var, init_output_ref)
graph.write_init_outputs(self.butler, skip_existing=self.extendRun)
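The per-task construct-and-put loop moves into pipe_base; `extendRun` maps onto `skip_existing`, so a rerun verifies already-stored init-outputs instead of failing on the `put`. A hedged sketch of the two modes, assuming a `qgraph` and writeable `butler` as in the QBB example above:

```python
# Fresh output run: construct each task and write every init-output.
qgraph.write_init_outputs(butler, skip_existing=False)

# --extend-run rerun: tolerate existing datasets, checking their types
# against the freshly constructed task attributes instead of rewriting.
qgraph.write_init_outputs(butler, skip_existing=True)
```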

def saveConfigs(self, graph: QuantumGraph) -> None:
"""Write configurations for pipeline tasks to butler or check that
@@ -225,49 +157,13 @@ def saveConfigs(self, graph: QuantumGraph) -> None:

Raises
------
TypeError
Raised if existing object in butler is different from new data.
Exception
Raised if ``extendRun`` is `False` and datasets already exists.
Content of a butler collection should not be changed if exception
is raised.
ConflictingDefinitionError
Raised if an existing object in the butler differs from the new data,
or if ``extendRun`` is `False` and the datasets already exist. The
content of a butler collection should not be changed if this exception
is raised.
"""

def logConfigMismatch(msg: str) -> None:
"""Log messages about configuration mismatch.

Parameters
----------
msg : `str`
Log message to use.
"""
_LOG.fatal("Comparing configuration: %s", msg)

_LOG.debug("Will save Configs for all tasks")
# start transaction to rollback any changes on exceptions
with self.transaction():
for taskDef in self._task_iter(graph):
# Config dataset ref is stored in task init outputs, but it
# may also be missing.
task_output_refs = graph.initOutputRefs(taskDef)
if task_output_refs is None:
continue

config_ref, old_config = self._find_dataset(task_output_refs, taskDef.configDatasetName)
if config_ref is None:
continue

if old_config is not None:
if not taskDef.config.compare(old_config, shortcut=False, output=logConfigMismatch):
raise TypeError(
f"Config does not match existing task config {taskDef.configDatasetName!r} in "
"butler; tasks configurations must be consistent within the same run collection"
)
else:
_LOG.debug(
"Saving Config for task=%s dataset type=%s", taskDef.label, taskDef.configDatasetName
)
self.butler.put(taskDef.config, config_ref)
graph.write_configs(self.butler, compare_existing=self.extendRun)
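Config persistence follows the same pattern, with `compare_existing` replacing the hand-rolled comparison; per the Raises section above, a drifted config now surfaces as `ConflictingDefinitionError` rather than the old `TypeError`. A sketch under the same assumptions:

```python
from lsst.daf.butler.registry import ConflictingDefinitionError

try:
    # --extend-run: compare against stored configs instead of rewriting.
    qgraph.write_configs(butler, compare_existing=True)
except ConflictingDefinitionError as err:
    # Task configurations must stay consistent within a run collection.
    print(f"config mismatch in output run: {err}")
```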

def savePackageVersions(self, graph: QuantumGraph) -> None:
"""Write versions of software packages to butler.
@@ -282,96 +178,7 @@ def savePackageVersions(self, graph: QuantumGraph) -> None:
TypeError
Raised if existing object in butler is incompatible with new data.
"""
packages = Packages.fromSystem()
_LOG.debug("want to save packages: %s", packages)

# start transaction to rollback any changes on exceptions
with self.transaction():
# Packages dataset ref is stored in graph's global init outputs,
# but it may also be missing.

packages_ref, old_packages = self._find_dataset(
graph.globalInitOutputRefs(), PACKAGES_INIT_OUTPUT_NAME
)
if packages_ref is None:
return

if old_packages is not None:
# Note that because we can only detect python modules that have
# been imported, the stored list of products may be more or
# less complete than what we have now. What's important is
# that the products that are in common have the same version.
_compare_packages(old_packages, packages)
# Update the old set of packages in case we have more packages
# that haven't been persisted.
extra = packages.extra(old_packages)
if extra:
_LOG.debug("extra packages: %s", extra)
old_packages.update(packages)
# have to remove existing dataset first, butler has no
# replace option.
self.butler.pruneDatasets([packages_ref], unstore=True, purge=True)
self.butler.put(old_packages, packages_ref)
else:
self.butler.put(packages, packages_ref)

def _find_dataset(
self, refs: Iterable[DatasetRef], dataset_type: str
) -> tuple[DatasetRef | None, Any | None]:
"""Find a ref with a given dataset type name in a list of references
and try to retrieve its data from butler.

Parameters
----------
refs : `~collections.abc.Iterable` [ `~lsst.daf.butler.DatasetRef` ]
References to check for matching dataset type.
dataset_type : `str`
Name of a dataset type to look for.

Returns
-------
ref : `~lsst.daf.butler.DatasetRef` or `None`
Dataset reference or `None` if there is no matching dataset type.
data : `Any`
An existing object extracted from butler, `None` if ``ref`` is
`None` or if there is no existing object for that reference.
"""
ref: DatasetRef | None = None
for ref in refs:
if ref.datasetType.name == dataset_type:
break
else:
return None, None

try:
data = self.butler.get(ref)
if data is not None and not self.extendRun:
# It must not exist unless we are extending run.
raise ConflictingDefinitionError(f"Dataset {ref} already exists in butler")
except (LookupError, FileNotFoundError):
data = None
return ref, data

def _task_iter(self, graph: QuantumGraph) -> Iterator[TaskDef]:
"""Iterate over TaskDefs in a graph, return only tasks that have one or
more associated quanta.
"""
for taskDef in graph.iterTaskGraph():
if graph.getNumberOfQuantaForTask(taskDef) > 0:
yield taskDef

@contextmanager
def transaction(self) -> Iterator[None]:
"""Context manager for transaction.

Default implementation has no transaction support.

Yields
------
`None`
No transaction support.
"""
yield
graph.write_packages(self.butler, compare_existing=self.extendRun)
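Package-version handling keeps the removed code's semantics: only packages common to the stored and current sets must agree on version (the stored list reflects whichever python modules happened to be imported), and newly seen packages are merged into the stored dataset. Sketch, same assumptions as above:

```python
# --extend-run rerun: verify versions of packages in common with the
# stored set and merge in any newly imported packages.
qgraph.write_packages(butler, compare_existing=True)
```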


class PreExecInit(PreExecInitBase):
@@ -402,65 +209,12 @@ def __init__(self, butler: Butler, taskFactory: TaskFactory, extendRun: bool = False):
"with a default output RUN collection."
)

@contextmanager
def transaction(self) -> Iterator[None]:
# docstring inherited
with self.full_butler.transaction():
yield

def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
# docstring inherited
missing_dataset_types: set[str] = set()
dataset_types = [node.dataset_type for node in graph.pipeline_graph.dataset_types.values()]
dataset_types.append(
DatasetType(
PACKAGES_INIT_OUTPUT_NAME, self.butler.dimensions.empty, PACKAGES_INIT_OUTPUT_STORAGE_CLASS
)
)
for dataset_type in dataset_types:
# Resolving the PipelineGraph when building the QuantumGraph should
# have already guaranteed that this is the registry dataset type
# and that all references to it use compatible storage classes,
# so we don't need another check for compatibility here; if the
# dataset type doesn't match the registry that's already a problem.
if registerDatasetTypes:
_LOG.debug("Registering DatasetType %s with registry", dataset_type.name)
try:
self.full_butler.registry.registerDatasetType(dataset_type)
except ConflictingDefinitionError:
expected = self.full_butler.registry.getDatasetType(dataset_type.name)
raise ConflictingDefinitionError(
f"DatasetType definition in registry has changed since the QuantumGraph was built: "
f"{dataset_type} (graph) != {expected} (registry)."
)
else:
_LOG.debug("Checking DatasetType %s against registry", dataset_type.name)
try:
expected = self.full_butler.registry.getDatasetType(dataset_type.name)
except MissingDatasetTypeError:
# Likely means that --register-dataset-types is forgotten,
# but we could also get here if there is a prerequisite
# input that is optional and none were found in this repo;
# that is not an error. And we don't bother to check if
# they are optional here, since the fact that we were able
# to make the QG says that they were, since there couldn't
# have been any datasets if the dataset types weren't
# registered.
if not graph.pipeline_graph.dataset_types[dataset_type.name].is_prerequisite:
missing_dataset_types.add(dataset_type.name)
continue
if expected != dataset_type:
raise ConflictingDefinitionError(
f"DatasetType definition in registry has changed since the QuantumGraph was built: "
f"{dataset_type} (graph) != {expected} (registry)."
)
if missing_dataset_types:
plural = "s" if len(missing_dataset_types) != 1 else ""
raise MissingDatasetTypeError(
f"Missing dataset type definition{plural}: {', '.join(missing_dataset_types)}. "
"Dataset types have to be registered with either `butler register-dataset-type` or "
"passing `--register-dataset-types` option to `pipetask run`."
)
if registerDatasetTypes:
graph.pipeline_graph.register_dataset_types(self.full_butler)
else:
graph.pipeline_graph.check_dataset_type_registrations(self.full_butler)
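Dataset-type registration moves to `PipelineGraph` as well, including the packages dataset type and the prerequisite-input carve-out the removed loop handled inline. A sketch of the two modes against a full `Butler` (the exception type is carried over from the removed code and assumed to survive the move):

```python
from lsst.daf.butler.registry import MissingDatasetTypeError

try:
    # Default mode: check graph dataset types against the registry.
    qgraph.pipeline_graph.check_dataset_type_registrations(full_butler)
except MissingDatasetTypeError:
    # Remedy from the removed error message: `butler register-dataset-type`
    # or `pipetask run --register-dataset-types`, which takes this branch.
    qgraph.pipeline_graph.register_dataset_types(full_butler)
```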


class PreExecInitLimited(PreExecInitBase):