diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py index f1bfe755b..1eee288b1 100644 --- a/python/lsst/pipe/base/pipeline.py +++ b/python/lsst/pipe/base/pipeline.py @@ -806,7 +806,11 @@ def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None: return if label not in self._pipelineIR.tasks: raise LookupError(f"There are no tasks labeled '{label}' in the pipeline") - self._pipelineIR.tasks[label].add_or_update_config(newConfig) + match self._pipelineIR.tasks[label]: + case pipelineIR.TaskIR() as task: + task.add_or_update_config(newConfig) + case pipelineIR._AmbigousTask() as ambig_task: + ambig_task.tasks[-1].add_or_update_config(newConfig) def write_to_uri(self, uri: ResourcePathExpression) -> None: """Write the pipeline to a file or directory. @@ -839,6 +843,7 @@ def to_graph(self, registry: Registry | None = None) -> pipeline_graph.PipelineG graph : `pipeline_graph.PipelineGraph` Representation of the pipeline as a graph. """ + self._pipelineIR.resolve_task_ambiguity() instrument_class_name = self._pipelineIR.instrument data_id = {} if instrument_class_name is not None: @@ -906,7 +911,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> """ if (taskIR := self._pipelineIR.tasks.get(label)) is None: raise NameError(f"Label {label} does not appear in this pipeline") - taskClass: type[PipelineTask] = doImportType(taskIR.klass) + # type ignore here because all ambiguity should be resolved + taskClass: type[PipelineTask] = doImportType(taskIR.klass) # type: ignore config = taskClass.ConfigClass() instrument: PipeBaseInstrument | None = None if (instrumentName := self._pipelineIR.instrument) is not None: @@ -915,7 +921,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> config.applyConfigOverrides( instrument, getattr(taskClass, "_DefaultName", ""), - taskIR.config, + # type ignore here because all ambiguity should be resolved + taskIR.config, # type: 
ignore self._pipelineIR.parameters, label, ) @@ -940,6 +947,7 @@ def __getitem__(self, item: str) -> TaskDef: # Making a whole graph and then making a TaskDef from that is pretty # backwards, but I'm hoping to deprecate this method shortly in favor # of making the graph explicitly and working with its node objects. + self._pipelineIR.resolve_task_ambiguity() graph = pipeline_graph.PipelineGraph() self._add_task_to_graph(item, graph) (result,) = graph._iter_task_defs() diff --git a/python/lsst/pipe/base/pipelineIR.py b/python/lsst/pipe/base/pipelineIR.py index 6bc0ad2c7..47a9b1d03 100644 --- a/python/lsst/pipe/base/pipelineIR.py +++ b/python/lsst/pipe/base/pipelineIR.py @@ -45,10 +45,11 @@ from collections import Counter from collections.abc import Generator, Hashable, Iterable, MutableMapping from dataclasses import dataclass, field -from typing import Any, Literal +from typing import Any, Literal, cast import yaml from lsst.resources import ResourcePath, ResourcePathExpression +from lsst.utils import doImportType from lsst.utils.introspection import find_outside_stacklevel @@ -442,6 +443,34 @@ def __eq__(self, other: object) -> bool: return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config")) +@dataclass +class _AmbigousTask: + """Representation of tasks which may have conflicting task classes.""" + + tasks: list[TaskIR] + """TaskIR objects that need to be compared later.""" + + def resolve(self) -> TaskIR: + true_taskIR = self.tasks[0] + task_class = doImportType(true_taskIR.klass) + # need to find out if they are all actually the same + for tmp_taskIR in self.tasks[1:]: + tmp_task_class = doImportType(tmp_taskIR.klass) + if tmp_task_class is task_class: + if tmp_taskIR.config is None: + continue + for config in tmp_taskIR.config: + true_taskIR.add_or_update_config(config) + else: + true_taskIR = tmp_taskIR + task_class = tmp_task_class + return true_taskIR + + def to_primitives(self) -> dict[str, str | list[dict]]: + 
true_task = self.resolve() + return true_task.to_primitives() + + @dataclass class ImportIR: """An intermediate representation of imported pipelines.""" @@ -777,7 +806,7 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None: existing in this object. """ # integrate any imported pipelines - accumulate_tasks: dict[str, TaskIR] = {} + accumulate_tasks: dict[str, TaskIR | _AmbigousTask] = {} accumulate_labeled_subsets: dict[str, LabeledSubset] = {} accumulated_parameters = ParametersIR({}) accumulated_steps: dict[str, StepIR] = {} @@ -841,17 +870,39 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None: for label, task in self.tasks.items(): if label not in accumulate_tasks: accumulate_tasks[label] = task - elif accumulate_tasks[label].klass == task.klass: - if task.config is not None: - for config in task.config: - accumulate_tasks[label].add_or_update_config(config) else: - accumulate_tasks[label] = task - self.tasks: dict[str, TaskIR] = accumulate_tasks + match (accumulate_tasks[label], task): + case (TaskIR() as taskir_obj, TaskIR() as ctask) if taskir_obj.klass == ctask.klass: + if ctask.config is not None: + for config in ctask.config: + taskir_obj.add_or_update_config(config) + case (TaskIR(klass=klass) as taskir_obj, TaskIR() as ctask) if klass != ctask.klass: + accumulate_tasks[label] = _AmbigousTask([taskir_obj, ctask]) + case (_AmbigousTask(ambig_list), TaskIR() as ctask): + ambig_list.append(ctask) + case (TaskIR() as taskir_obj, _AmbigousTask(ambig_list)): + accumulate_tasks[label] = _AmbigousTask([taskir_obj] + ambig_list) + case (_AmbigousTask(existing_ambig_list), _AmbigousTask(new_ambig_list)): + existing_ambig_list.extend(new_ambig_list) + + self.tasks: MutableMapping[str, TaskIR | _AmbigousTask] = accumulate_tasks accumulated_parameters.update(self.parameters) self.parameters = accumulated_parameters self.steps = list(accumulated_steps.values()) + def resolve_task_ambiguity(self) -> None: + new_tasks: dict[str, TaskIR] 
= {} + for label, task in self.tasks.items(): + match task: + case TaskIR(): + new_tasks[label] = task + case _AmbigousTask(): + new_tasks[label] = task.resolve() + # Do a cast here, because within this function body we want the + # protection that all the tasks are TaskIR objects, but for the + # task level variable, it must stay the same mixed dictionary. + self.tasks = cast(dict[str, TaskIR | _AmbigousTask], new_tasks) + def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None: """Process the tasks portion of the loaded yaml document @@ -869,6 +920,7 @@ def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None: if "parameters" in tmp_tasks: raise ValueError("parameters is a reserved word and cannot be used as a task label") + definition: str | dict[str, Any] for label, definition in tmp_tasks.items(): if isinstance(definition, str): definition = {"class": definition} diff --git a/python/lsst/pipe/base/tests/pipelineIRTestClasses.py b/python/lsst/pipe/base/tests/pipelineIRTestClasses.py new file mode 100644 index 000000000..40eca4154 --- /dev/null +++ b/python/lsst/pipe/base/tests/pipelineIRTestClasses.py @@ -0,0 +1,48 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +"""Module defining PipelineIR test classes. +""" + +from __future__ import annotations + +__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace") + + +class ModuleA: + """PipelineIR test class for importing.""" + + pass + + +ModuleAAlias = ModuleA + + +class ModuleAReplace: + """PipelineIR test class for importing.""" + + pass diff --git a/python/lsst/pipe/base/tests/simpleQGraph.py b/python/lsst/pipe/base/tests/simpleQGraph.py index 60cd86ebb..616f7b46b 100644 --- a/python/lsst/pipe/base/tests/simpleQGraph.py +++ b/python/lsst/pipe/base/tests/simpleQGraph.py @@ -199,6 +199,114 @@ def makeTask( return task +class SubTaskConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), + defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"}, +): + """Connections for SubTask, has one input and two outputs, + plus one init output. 
+ """ + + input = cT.Input( + name="add_dataset{in_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Input dataset type for this task", + ) + output = cT.Output( + name="add_dataset{out_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Output dataset type for this task", + ) + output2 = cT.Output( + name="add2_dataset{out_tmpl}", + dimensions=["instrument", "detector"], + storageClass="NumpyArray", + doc="Output dataset type for this task", + ) + initout = cT.InitOutput( + name="add_init_output{out_tmpl}", + storageClass="NumpyArray", + doc="Init Output dataset type for this task", + ) + + +class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections): + """Config for SubTask.""" + + subtract = pexConfig.Field[int](doc="amount to subtract", default=3) + + +class SubTask(PipelineTask): + """Trivial PipelineTask for testing, has some extras useful for specific + unit tests. + """ + + ConfigClass = SubTaskConfig + _DefaultName = "sub_task" + + initout = numpy.array([999]) + """InitOutputs for this task""" + + taskFactory: SubTaskFactoryMock | None = None + """Factory that makes instances""" + + def run(self, input: int) -> Struct: + if self.taskFactory: + # do some bookkeeping + if self.taskFactory.stopAt == self.taskFactory.countExec: + raise RuntimeError("pretend something bad happened") + self.taskFactory.countExec -= 1 + + self.config = cast(SubTaskConfig, self.config) + self.metadata.add("sub", self.config.subtract) + output = input - self.config.subtract + output2 = output + self.config.subtract + _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2) + return Struct(output=output, output2=output2) + + +class SubTaskFactoryMock(TaskFactory): + """Special task factory that instantiates SubTask. + + It also defines some bookkeeping variables used by SubTask to report + progress to unit tests. 
+ + Parameters + ---------- + stopAt : `int`, optional + Value of `countExec` at which `SubTask.run` raises an exception. + """ + + def __init__(self, stopAt: int = -1): + self.countExec = 100 # reduced by SubTask + self.stopAt = stopAt # SubTask raises when countExec reaches this value + + def makeTask( + self, + task_node: TaskDef | TaskNode, + /, + butler: LimitedButler, + initInputRefs: Iterable[DatasetRef] | None, + ) -> PipelineTask: + if isinstance(task_node, TaskDef): + # TODO: remove support on DM-40443. + warnings.warn( + "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.", + FutureWarning, + find_outside_stacklevel("lsst.pipe.base"), + ) + task_class = task_node.taskClass + assert task_class is not None + else: + task_class = task_node.task_class + task = task_class(config=task_node.config, initInputs=None, name=task_node.label) + task.taskFactory = self # type: ignore + return task + + def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None: """Register all dataset types used by tasks in a registry. 
diff --git a/tests/testPipeline2.yaml b/tests/testPipeline2.yaml index 96c96909c..df3f61aaa 100644 --- a/tests/testPipeline2.yaml +++ b/tests/testPipeline2.yaml @@ -4,7 +4,7 @@ parameters: value3: valueC tasks: modA: - class: "test.moduleA" + class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA" config: value1: 1 subsets: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a59aa9cc8..eca8e6cd0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -35,7 +35,7 @@ import lsst.utils.tests from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef from lsst.pipe.base.pipelineIR import LabeledSubset -from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline +from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline class PipelineTestCase(unittest.TestCase): @@ -131,6 +131,48 @@ def testMergingPipelines(self): pipeline1.mergePipeline(pipeline2) self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"}) + # Test merging pipelines with ambiguous tasks + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline2.mergePipeline(pipeline1) + + # Now merge in another pipeline with a config applied. 
+ pipeline3 = makeSimplePipeline(2) + pipeline3.addTask(SubTask, "task1") + pipeline3.addConfigOverride("task1", "subtract", 10) + pipeline3.mergePipeline(pipeline2) + graph = pipeline3.to_graph() + # assert equality from the graph to trigger ambiguity resolution + self.assertEqual(graph.tasks["task1"].config.subtract, 10) + + # Now change the order of the merging + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline3 = makeSimplePipeline(2) + pipeline3.mergePipeline(pipeline2) + pipeline3.mergePipeline(pipeline1) + graph = pipeline3.to_graph() + # assert equality from the graph to trigger ambiguity resolution + self.assertEqual(graph.tasks["task1"].config.addend, 3) + + # Now do two ambiguous chains + pipeline1 = makeSimplePipeline(2) + pipeline2 = makeSimplePipeline(2) + pipeline2.addTask(SubTask, "task1") + pipeline2.addConfigOverride("task1", "subtract", 10) + pipeline2.mergePipeline(pipeline1) + + pipeline3 = makeSimplePipeline(2) + pipeline4 = makeSimplePipeline(2) + pipeline4.addTask(SubTask, "task1") + pipeline4.addConfigOverride("task1", "subtract", 7) + pipeline4.mergePipeline(pipeline3) + graph = pipeline4.to_graph() + # assert equality from the graph to trigger ambiguity resolution + self.assertEqual(graph.tasks["task1"].config.subtract, 7) + def testFindingSubset(self): pipeline = makeSimplePipeline(2) pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None) diff --git a/tests/test_pipelineIR.py b/tests/test_pipelineIR.py index 36c321135..def304fdd 100644 --- a/tests/test_pipelineIR.py +++ b/tests/test_pipelineIR.py @@ -230,12 +230,31 @@ def testImportParsing(self): - $TESTDIR/testPipeline2.yaml tasks: modA: - class: "test.moduleA" + class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA" config: value2: 2 """ ) pipeline = PipelineIR.from_string(pipeline_str) + pipeline.resolve_task_ambiguity() + self.assertEqual(pipeline.tasks["modA"].config[0].rest, 
{"value1": 1, "value2": 2}) + + # Test that configs are imported when defining the same task again + # that is aliased with the same label + pipeline_str = textwrap.dedent( + """ + description: Test Pipeline + imports: + - $TESTDIR/testPipeline2.yaml + tasks: + modA: + class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAAlias" + config: + value2: 2 + """ + ) + pipeline = PipelineIR.from_string(pipeline_str) + pipeline.resolve_task_ambiguity() self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value1": 1, "value2": 2}) # Test that configs are not imported when redefining the task @@ -247,12 +266,13 @@ def testImportParsing(self): - $TESTDIR/testPipeline2.yaml tasks: modA: - class: "test.moduleAReplace" + class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAReplace" config: value2: 2 """ ) pipeline = PipelineIR.from_string(pipeline_str) + pipeline.resolve_task_ambiguity() self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value2": 2}) # Test that named subsets are imported