Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-44648: Allow conflicting module paths to be resolved late #421

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,11 @@
return
if label not in self._pipelineIR.tasks:
raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
self._pipelineIR.tasks[label].add_or_update_config(newConfig)
match self._pipelineIR.tasks[label]:
case pipelineIR.TaskIR() as task:
task.add_or_update_config(newConfig)
case pipelineIR._AmbigousTask() as ambig_task:
ambig_task.tasks[-1].add_or_update_config(newConfig)

Check warning on line 813 in python/lsst/pipe/base/pipeline.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipeline.py#L812-L813

Added lines #L812 - L813 were not covered by tests

def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""Write the pipeline to a file or directory.
Expand Down Expand Up @@ -839,6 +843,7 @@
graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
self._pipelineIR.resolve_task_ambiguity()
instrument_class_name = self._pipelineIR.instrument
data_id = {}
if instrument_class_name is not None:
Expand Down Expand Up @@ -906,7 +911,8 @@
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
# type ignore here because all ambiguity should be resolved
taskClass: type[PipelineTask] = doImportType(taskIR.klass) # type: ignore
config = taskClass.ConfigClass()
instrument: PipeBaseInstrument | None = None
if (instrumentName := self._pipelineIR.instrument) is not None:
Expand All @@ -915,7 +921,8 @@
config.applyConfigOverrides(
instrument,
getattr(taskClass, "_DefaultName", ""),
taskIR.config,
# type ignore here because all ambiguity should be resolved
taskIR.config, # type: ignore
self._pipelineIR.parameters,
label,
)
Expand All @@ -940,6 +947,7 @@
# Making a whole graph and then making a TaskDef from that is pretty
# backwards, but I'm hoping to deprecate this method shortly in favor
# of making the graph explicitly and working with its node objects.
self._pipelineIR.resolve_task_ambiguity()

Check warning on line 950 in python/lsst/pipe/base/pipeline.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipeline.py#L950

Added line #L950 was not covered by tests
graph = pipeline_graph.PipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
Expand Down
68 changes: 60 additions & 8 deletions python/lsst/pipe/base/pipelineIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import Any, Literal
from typing import Any, Literal, cast

import yaml
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils import doImportType


class PipelineSubsetCtrl(enum.Enum):
Expand Down Expand Up @@ -442,6 +443,34 @@
return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class _AmbigousTask:
    """Group of task definitions whose task classes may conflict.

    The decision about which definition applies is postponed until
    `resolve` is called, because it requires importing each class.
    """

    tasks: list[TaskIR]
    """TaskIR objects that need to be compared late."""

    def resolve(self) -> TaskIR:
        """Reduce the candidate definitions to a single `TaskIR`.

        The first entry of ``tasks`` starts out as the winner. Every later
        entry has its ``klass`` imported with `doImportType` and compared,
        by identity, with the imported class of the current winner. When
        they are the same object the entry's config blocks (if any) are
        passed one by one to the winner's ``add_or_update_config``; when
        they differ the entry becomes the new winner and whatever was
        accumulated before it is no longer used.

        Returns
        -------
        resolved : `TaskIR`
            The winning element of ``tasks``. It is one of the stored
            objects, not a copy.
        """
        winner = self.tasks[0]
        winner_type = doImportType(winner.klass)
        for candidate in self.tasks[1:]:
            candidate_type = doImportType(candidate.klass)
            if candidate_type is not winner_type:
                # Different class: this definition supersedes the previous
                # winner entirely.
                winner, winner_type = candidate, candidate_type
                continue
            # Same class reached through a possibly different import path:
            # fold this definition's config into the winner.
            # NOTE(review): the coverage annotation attached to this change
            # reported the "config is None" case as not exercised by tests.
            if candidate.config is not None:
                for override in candidate.config:
                    winner.add_or_update_config(override)
        return winner

    def to_primitives(self) -> dict[str, str | list[dict]]:
        """Resolve the ambiguity and return the winner's primitives.

        Returns
        -------
        primitives : `dict`
            Whatever ``to_primitives`` of the `TaskIR` returned by
            `resolve` produces.
        """
        return self.resolve().to_primitives()

Check warning on line 471 in python/lsst/pipe/base/pipelineIR.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipelineIR.py#L470-L471

Added lines #L470 - L471 were not covered by tests


@dataclass
class ImportIR:
"""An intermediate representation of imported pipelines."""
Expand Down Expand Up @@ -777,7 +806,7 @@
existing in this object.
"""
# integrate any imported pipelines
accumulate_tasks: dict[str, TaskIR] = {}
accumulate_tasks: dict[str, TaskIR | _AmbigousTask] = {}
accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
accumulated_parameters = ParametersIR({})
accumulated_steps: dict[str, StepIR] = {}
Expand Down Expand Up @@ -841,17 +870,39 @@
for label, task in self.tasks.items():
if label not in accumulate_tasks:
accumulate_tasks[label] = task
elif accumulate_tasks[label].klass == task.klass:
if task.config is not None:
for config in task.config:
accumulate_tasks[label].add_or_update_config(config)
else:
accumulate_tasks[label] = task
self.tasks: dict[str, TaskIR] = accumulate_tasks
match (accumulate_tasks[label], task):
case (TaskIR() as taskir_obj, TaskIR() as ctask) if taskir_obj.klass == ctask.klass:
if ctask.config is not None:
for config in ctask.config:
taskir_obj.add_or_update_config(config)
case (TaskIR(klass=klass) as taskir_obj, TaskIR() as ctask) if klass != ctask.klass:
accumulate_tasks[label] = _AmbigousTask([taskir_obj, ctask])
case (_AmbigousTask(ambig_list), TaskIR() as ctask):
ambig_list.append(ctask)

Check warning on line 882 in python/lsst/pipe/base/pipelineIR.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipelineIR.py#L882

Added line #L882 was not covered by tests
case (TaskIR() as taskir_obj, _AmbigousTask(ambig_list)):
accumulate_tasks[label] = _AmbigousTask([taskir_obj] + ambig_list)

Check warning on line 884 in python/lsst/pipe/base/pipelineIR.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipelineIR.py#L884

Added line #L884 was not covered by tests
case (_AmbigousTask(existing_ambig_list), _AmbigousTask(new_ambig_list)):
existing_ambig_list.extend(new_ambig_list)

Check warning on line 886 in python/lsst/pipe/base/pipelineIR.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/pipelineIR.py#L886

Added line #L886 was not covered by tests

self.tasks: MutableMapping[str, TaskIR | _AmbigousTask] = accumulate_tasks
accumulated_parameters.update(self.parameters)
self.parameters = accumulated_parameters
self.steps = list(accumulated_steps.values())

def resolve_task_ambiguity(self) -> None:
    """Replace every ambiguous task entry with its resolved `TaskIR`.

    Iterates over ``self.tasks``: plain `TaskIR` values are carried over
    unchanged, `_AmbigousTask` values are replaced by the result of their
    ``resolve`` method. ``self.tasks`` is then rebound to the newly built
    dictionary.
    """
    resolved: dict[str, TaskIR] = {}
    for name, entry in self.tasks.items():
        if isinstance(entry, TaskIR):
            resolved[name] = entry
        elif isinstance(entry, _AmbigousTask):
            resolved[name] = entry.resolve()
    # The local dictionary is typed as holding only TaskIR objects so the
    # body of this method is checked against that guarantee; the attribute
    # itself keeps its mixed value type, hence the cast on assignment.
    self.tasks = cast(dict[str, TaskIR | _AmbigousTask], resolved)

def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
"""Process the tasks portion of the loaded yaml document

Expand All @@ -869,6 +920,7 @@
if "parameters" in tmp_tasks:
raise ValueError("parameters is a reserved word and cannot be used as a task label")

definition: str | dict[str, Any]
for label, definition in tmp_tasks.items():
if isinstance(definition, str):
definition = {"class": definition}
Expand Down
48 changes: 48 additions & 0 deletions python/lsst/pipe/base/tests/pipelineIRTestClasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining PipelineIR test classes
"""

from __future__ import annotations

__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace")


class ModuleA:
    """Empty placeholder class that PipelineIR tests import by dotted name."""


# Second module-level name bound to the very same class object as ModuleA
# (not a subclass or a copy), so importing either name yields one type.
ModuleAAlias = ModuleA


class ModuleAReplace:
    """Empty placeholder class, distinct from `ModuleA`, that PipelineIR
    tests import by dotted name.
    """
2 changes: 1 addition & 1 deletion tests/testPipeline2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ parameters:
value3: valueC
tasks:
modA:
class: "test.moduleA"
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
config:
value1: 1
subsets:
Expand Down
24 changes: 22 additions & 2 deletions tests/test_pipelineIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,31 @@ def testImportParsing(self):
- $TESTDIR/testPipeline2.yaml
tasks:
modA:
class: "test.moduleA"
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
config:
value2: 2
"""
)
pipeline = PipelineIR.from_string(pipeline_str)
pipeline.resolve_task_ambiguity()
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value1": 1, "value2": 2})

# Test that configs are imported when defining the same task again
# that is aliased with the same label
pipeline_str = textwrap.dedent(
"""
description: Test Pipeline
imports:
- $TESTDIR/testPipeline2.yaml
tasks:
modA:
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAAlias"
config:
value2: 2
"""
)
pipeline = PipelineIR.from_string(pipeline_str)
pipeline.resolve_task_ambiguity()
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value1": 1, "value2": 2})

# Test that configs are not imported when redefining the task
Expand All @@ -247,12 +266,13 @@ def testImportParsing(self):
- $TESTDIR/testPipeline2.yaml
tasks:
modA:
class: "test.moduleAReplace"
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleAReplace"
config:
value2: 2
"""
)
pipeline = PipelineIR.from_string(pipeline_str)
pipeline.resolve_task_ambiguity()
self.assertEqual(pipeline.tasks["modA"].config[0].rest, {"value2": 2})

# Test that named subsets are imported
Expand Down
Loading