Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-44648: Allow conflicting module paths to be resolved late #421

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,11 @@
return
if label not in self._pipelineIR.tasks:
raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
self._pipelineIR.tasks[label].add_or_update_config(newConfig)
match self._pipelineIR.tasks[label]:
case pipelineIR.TaskIR() as task:
task.add_or_update_config(newConfig)
case pipelineIR._AmbigousTask() as ambig_task:

Check warning on line 812 in python/lsst/pipe/base/pipeline.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipeline.py#L812

Added line #L812 was not covered by tests
ambig_task.tasks[-1].add_or_update_config(newConfig)

def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""Write the pipeline to a file or directory.
Expand Down Expand Up @@ -839,6 +843,7 @@
graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
self._pipelineIR.resolve_task_ambiguity()
instrument_class_name = self._pipelineIR.instrument
data_id = {}
if instrument_class_name is not None:
Expand Down Expand Up @@ -906,7 +911,8 @@
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
# type ignore here because all ambiguity should be resolved
taskClass: type[PipelineTask] = doImportType(taskIR.klass) # type: ignore
config = taskClass.ConfigClass()
instrument: PipeBaseInstrument | None = None
if (instrumentName := self._pipelineIR.instrument) is not None:
Expand All @@ -915,7 +921,8 @@
config.applyConfigOverrides(
instrument,
getattr(taskClass, "_DefaultName", ""),
taskIR.config,
# type ignore here because all ambiguity should be resolved
taskIR.config, # type: ignore
self._pipelineIR.parameters,
label,
)
Expand All @@ -940,6 +947,7 @@
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
self._pipelineIR.resolve_task_ambiguity()
graph = pipeline_graph.PipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
Expand Down
68 changes: 60 additions & 8 deletions python/lsst/pipe/base/pipelineIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import Any, Literal
from typing import Any, Literal, cast

import yaml
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import find_outside_stacklevel


Expand Down Expand Up @@ -442,6 +443,34 @@
return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class _AmbigousTask:
    """Representation of tasks which may have conflicting task classes.

    When pipelines are merged, two declarations for the same label may name
    different (or differently-spelled) task classes.  Rather than importing
    the classes eagerly to compare them, the conflicting declarations are
    collected here and resolved lazily via `resolve`.
    """

    tasks: list[TaskIR]
    """TaskIR objects that need to be compared late."""

    def resolve(self) -> TaskIR:
        """Resolve the conflicting declarations into a single `TaskIR`.

        Returns
        -------
        task : `TaskIR`
            The winning task definition with all applicable configs merged.

        Notes
        -----
        Declarations are walked in merge order.  A declaration whose class
        imports to the same object as the current winner has its configs
        merged in; a declaration with a genuinely different class replaces
        the current winner outright, discarding what was accumulated so far.
        """
        true_taskIR = self.tasks[0]
        task_class = doImportType(true_taskIR.klass)
        # need to find out if they are all actually the same
        for tmp_taskIR in self.tasks[1:]:
            tmp_task_class = doImportType(tmp_taskIR.klass)
            if tmp_task_class is task_class:
                # Same underlying class object (possibly via an alias path):
                # fold any configs into the accumulated winner.
                if tmp_taskIR.config is None:
                    continue
                for config in tmp_taskIR.config:
                    true_taskIR.add_or_update_config(config)
            else:
                # Different class: the later declaration wins entirely.
                true_taskIR = tmp_taskIR
                task_class = tmp_task_class
        return true_taskIR

    def to_primitives(self) -> dict[str, str | list[dict]]:
        """Convert to a representation used in yaml serialization,
        resolving the ambiguity first.
        """
        true_task = self.resolve()
        return true_task.to_primitives()


@dataclass
class ImportIR:
"""An intermediate representation of imported pipelines."""
Expand Down Expand Up @@ -777,7 +806,7 @@
existing in this object.
"""
# integrate any imported pipelines
accumulate_tasks: dict[str, TaskIR] = {}
accumulate_tasks: dict[str, TaskIR | _AmbigousTask] = {}
accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
accumulated_parameters = ParametersIR({})
accumulated_steps: dict[str, StepIR] = {}
Expand Down Expand Up @@ -841,17 +870,39 @@
for label, task in self.tasks.items():
if label not in accumulate_tasks:
accumulate_tasks[label] = task
elif accumulate_tasks[label].klass == task.klass:
if task.config is not None:
for config in task.config:
accumulate_tasks[label].add_or_update_config(config)
else:
accumulate_tasks[label] = task
self.tasks: dict[str, TaskIR] = accumulate_tasks
match (accumulate_tasks[label], task):
case (TaskIR() as taskir_obj, TaskIR() as ctask) if taskir_obj.klass == ctask.klass:
if ctask.config is not None:
for config in ctask.config:
taskir_obj.add_or_update_config(config)
case (TaskIR(klass=klass) as taskir_obj, TaskIR() as ctask) if klass != ctask.klass:
accumulate_tasks[label] = _AmbigousTask([taskir_obj, ctask])
case (_AmbigousTask(ambig_list), TaskIR() as ctask):
ambig_list.append(ctask)
case (TaskIR() as taskir_obj, _AmbigousTask(ambig_list)):
accumulate_tasks[label] = _AmbigousTask([taskir_obj] + ambig_list)
case (_AmbigousTask(existing_ambig_list), _AmbigousTask(new_ambig_list)):
existing_ambig_list.extend(new_ambig_list)

Check warning on line 886 in python/lsst/pipe/base/pipelineIR.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipelineIR.py#L886

Added line #L886 was not covered by tests

self.tasks: MutableMapping[str, TaskIR | _AmbigousTask] = accumulate_tasks
accumulated_parameters.update(self.parameters)
self.parameters = accumulated_parameters
self.steps = list(accumulated_steps.values())

def resolve_task_ambiguity(self) -> None:
    """Collapse any `_AmbigousTask` entries in ``self.tasks`` into
    concrete `TaskIR` objects by importing and comparing their classes.
    """
    resolved: dict[str, TaskIR] = {}
    for label, entry in self.tasks.items():
        if isinstance(entry, _AmbigousTask):
            resolved[label] = entry.resolve()
        elif isinstance(entry, TaskIR):
            resolved[label] = entry
    # The local dict is typed as pure TaskIR for safety inside this method,
    # but the attribute keeps its mixed declared type, hence the cast.
    self.tasks = cast(dict[str, TaskIR | _AmbigousTask], resolved)

def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
"""Process the tasks portion of the loaded yaml document

Expand All @@ -869,6 +920,7 @@
if "parameters" in tmp_tasks:
raise ValueError("parameters is a reserved word and cannot be used as a task label")

definition: str | dict[str, Any]
for label, definition in tmp_tasks.items():
if isinstance(definition, str):
definition = {"class": definition}
Expand Down
48 changes: 48 additions & 0 deletions python/lsst/pipe/base/tests/pipelineIRTestClasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining PipelineIR test classes.
"""

from __future__ import annotations

__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace")


class ModuleA:
    """PipelineIR test class for importing."""


# Alias bound to the same class object as ModuleA; used to test that two
# pipeline declarations naming the "same" class via different paths do not
# conflict.
ModuleAAlias = ModuleA


class ModuleAReplace:
    """PipelineIR test class for importing."""
108 changes: 108 additions & 0 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,114 @@
return task


class SubTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for SubTask, has one input and two outputs,
    plus one init output.
    """

    # Per-detector input; the "{in_tmpl}" template lets tests vary the
    # dataset name suffix without redefining the connection.
    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    # Primary output produced by SubTask.run.
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Secondary output; SubTask.run computes it from the primary output.
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Dimensionless init-output written at task construction time.
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections):
    """Config for SubTask."""

    # Amount subtracted from the input value in SubTask.run.
    subtract = pexConfig.Field[int](doc="amount to subtract", default=3)


class SubTask(PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.
    """

    ConfigClass = SubTaskConfig
    _DefaultName = "sub_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: SubTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        """Subtract the configured amount from ``input``.

        Parameters
        ----------
        input : `int`
            Value to operate on.

        Returns
        -------
        result : `Struct`
            Struct with ``output`` (input minus ``config.subtract``) and
            ``output2`` (``output`` plus ``config.subtract``, i.e. the
            original input).
        """
        if self.taskFactory:
            # do some bookkeeping
            if self.taskFactory.stopAt == self.taskFactory.countExec:
                raise RuntimeError("pretend something bad happened")
            self.taskFactory.countExec -= 1

        # Narrow the config type for the benefit of static checkers.
        self.config = cast(SubTaskConfig, self.config)
        self.metadata.add("sub", self.config.subtract)
        output = input - self.config.subtract
        output2 = output + self.config.subtract
        _LOG.info("input = %s, output = %s, output2 = %s", input, output, output2)
        return Struct(output=output, output2=output2)


class SubTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates SubTask.

    It also defines some bookkeeping variables used by SubTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Number of times to call `run` before stopping.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 100  # reduced by SubTask
        self.stopAt = stopAt  # SubTask raises exception at this call to run()

    def makeTask(
        self,
        task_node: TaskDef | TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        """Construct a `SubTask` instance for the given task node.

        Also attaches this factory to the task so the task can do its
        bookkeeping during `run`.
        """
        if isinstance(task_node, TaskDef):
            # TODO: remove support on DM-40443.
            warnings.warn(
                "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
                FutureWarning,
                find_outside_stacklevel("lsst.pipe.base"),
            )
            task_class = task_node.taskClass
            assert task_class is not None
        else:
            task_class = task_node.task_class
        task = task_class(config=task_node.config, initInputs=None, name=task_node.label)
        # type ignore: PipelineTask has no declared taskFactory attribute.
        task.taskFactory = self  # type: ignore
        return task


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | Iterable[TaskDef] | PipelineGraph) -> None:
"""Register all dataset types used by tasks in a registry.

Expand Down
2 changes: 1 addition & 1 deletion tests/testPipeline2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ parameters:
value3: valueC
tasks:
modA:
class: "test.moduleA"
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
config:
value1: 1
subsets:
Expand Down
44 changes: 43 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import lsst.utils.tests
from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef
from lsst.pipe.base.pipelineIR import LabeledSubset
from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline
from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline


class PipelineTestCase(unittest.TestCase):
Expand Down Expand Up @@ -131,6 +131,48 @@ def testMergingPipelines(self):
pipeline1.mergePipeline(pipeline2)
self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"})

# Test merging pipelines with ambiguous tasks
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.mergePipeline(pipeline1)

# Now merge in another pipeline with a config applied.
pipeline3 = makeSimplePipeline(2)
pipeline3.addTask(SubTask, "task1")
pipeline3.addConfigOverride("task1", "subtract", 10)
pipeline3.mergePipeline(pipeline2)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiguity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 10)

# Now change the order of the merging
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline3 = makeSimplePipeline(2)
pipeline3.mergePipeline(pipeline2)
pipeline3.mergePipeline(pipeline1)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiguity resolution
self.assertEqual(graph.tasks["task1"].config.addend, 3)

# Now do two ambiguous chains
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.addConfigOverride("task1", "subtract", 10)
pipeline2.mergePipeline(pipeline1)

pipeline3 = makeSimplePipeline(2)
pipeline4 = makeSimplePipeline(2)
pipeline4.addTask(SubTask, "task1")
pipeline4.addConfigOverride("task1", "subtract", 7)
pipeline4.mergePipeline(pipeline3)
graph = pipeline4.to_graph()
# assert equality from the graph to trigger ambiguity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 7)

def testFindingSubset(self):
pipeline = makeSimplePipeline(2)
pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)
Expand Down
Loading
Loading