diff --git a/python/lsst/pipe/base/dot_tools.py b/python/lsst/pipe/base/dot_tools.py
new file mode 100644
index 000000000..01e24e0f1
--- /dev/null
+++ b/python/lsst/pipe/base/dot_tools.py
@@ -0,0 +1,351 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Module defining few methods to generate GraphViz diagrams from pipelines
+or quantum graphs.
+"""
+
+from __future__ import annotations
+
+__all__ = ["graph2dot", "pipeline2dot"]
+
+# -------------------------------
+# Imports of standard modules --
+# -------------------------------
+import html
+import io
+import re
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any
+
+# -----------------------------
+# Imports for other modules --
+# -----------------------------
+from lsst.daf.butler import DatasetType, DimensionUniverse
+
+from . import connectionTypes
+from .connections import iterConnections
+from .pipeline import Pipeline
+
+if TYPE_CHECKING:
+ from lsst.daf.butler import DatasetRef
+ from lsst.pipe.base import QuantumGraph, QuantumNode, TaskDef
+
+# ----------------------------------
+# Local non-exported definitions --
+# ----------------------------------
+
+# Attributes applied to directed graph objects.
+_NODELABELPOINTSIZE = "18"
+_ATTRIBS = dict(
+ defaultGraph=dict(splines="ortho", nodesep="0.5", ranksep="0.75", pad="0.5"),
+ defaultNode=dict(shape="box", fontname="Monospace", fontsize="14", margin="0.2,0.1", penwidth="3"),
+ defaultEdge=dict(color="black", arrowsize="1.5", penwidth="1.5"),
+ task=dict(style="filled", color="black", fillcolor="#B1F2EF"),
+ quantum=dict(style="filled", color="black", fillcolor="#B1F2EF"),
+ dsType=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
+ dataset=dict(style="rounded,filled,bold", color="#00BABC", fillcolor="#F5F5F5"),
+)
+
+
+def _renderDefault(type: str, attribs: dict[str, str], file: io.TextIOBase) -> None:
+ """Set default attributes for a given type."""
+ default_attribs = ", ".join([f'{key}="{val}"' for key, val in attribs.items()])
+ print(f"{type} [{default_attribs}];", file=file)
+
+
+def _renderNode(file: io.TextIOBase, nodeName: str, style: str, labels: list[str]) -> None:
+ """Render GV node"""
+ label = r"
".join(labels)
+ attrib_dict = dict(_ATTRIBS[style], label=label)
+ pre = '<>"
+ attrib = ", ".join(
+ [
+ f'{key}="{val}"' if key != "label" else f"{key}={pre}{val}{post}"
+ for key, val in attrib_dict.items()
+ ]
+ )
+ print(f'"{nodeName}" [{attrib}];', file=file)
+
+
+def _renderTaskNode(nodeName: str, taskDef: TaskDef, file: io.TextIOBase, idx: Any = None) -> None:
+ """Render GV node for a task"""
+ labels = [
+ f'' + html.escape(taskDef.label) + "",
+ html.escape(taskDef.taskName),
+ ]
+ if idx is not None:
+ labels.append(f"index: {idx}")
+ if taskDef.connections:
+ # don't print collection of str directly to avoid visually noisy quotes
+ dimensions_str = ", ".join(sorted(taskDef.connections.dimensions))
+ labels.append(f"dimensions: {html.escape(dimensions_str)}")
+ _renderNode(file, nodeName, "task", labels)
+
+
+def _renderQuantumNode(
+ nodeName: str, taskDef: TaskDef, quantumNode: QuantumNode, file: io.TextIOBase
+) -> None:
+ """Render GV node for a quantum"""
+ labels = [f"{quantumNode.nodeId}", html.escape(taskDef.label)]
+ dataId = quantumNode.quantum.dataId
+ assert dataId is not None, "Quantum DataId cannot be None"
+ labels.extend(f"{key} = {dataId[key]}" for key in sorted(dataId.required.keys()))
+ _renderNode(file, nodeName, "quantum", labels)
+
+
+def _renderDSTypeNode(name: str, dimensions: list[str], file: io.TextIOBase) -> None:
+ """Render GV node for a dataset type"""
+ labels = [f'' + html.escape(name) + ""]
+ if dimensions:
+ labels.append("dimensions: " + html.escape(", ".join(sorted(dimensions))))
+ _renderNode(file, name, "dsType", labels)
+
+
+def _renderDSNode(nodeName: str, dsRef: DatasetRef, file: io.TextIOBase) -> None:
+ """Render GV node for a dataset"""
+ labels = [html.escape(dsRef.datasetType.name), f"run: {dsRef.run!r}"]
+ labels.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.required.keys()))
+ _renderNode(file, nodeName, "dataset", labels)
+
+
+def _renderEdge(fromName: str, toName: str, file: io.TextIOBase, **kwargs: Any) -> None:
+ """Render GV edge"""
+ if kwargs:
+ attrib = ", ".join([f'{key}="{val}"' for key, val in kwargs.items()])
+ print(f'"{fromName}" -> "{toName}" [{attrib}];', file=file)
+ else:
+ print(f'"{fromName}" -> "{toName}";', file=file)
+
+
+def _datasetRefId(dsRef: DatasetRef) -> str:
+ """Make an identifying string for given ref"""
+ dsId = [dsRef.datasetType.name]
+ dsId.extend(f"{key} = {dsRef.dataId[key]}" for key in sorted(dsRef.dataId.required.keys()))
+ return ":".join(dsId)
+
+
+def _makeDSNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: io.TextIOBase) -> str:
+ """Make new node for dataset if it does not exist.
+
+ Returns node name.
+ """
+ dsRefId = _datasetRefId(dsRef)
+ nodeName = allDatasetRefs.get(dsRefId)
+ if nodeName is None:
+ idx = len(allDatasetRefs)
+ nodeName = f"dsref_{idx}"
+ allDatasetRefs[dsRefId] = nodeName
+ _renderDSNode(nodeName, dsRef, file)
+ return nodeName
+
+
+# ------------------------
+# Exported definitions --
+# ------------------------
+
+
+def graph2dot(qgraph: QuantumGraph, file: Any) -> None:
+ """Convert QuantumGraph into GraphViz digraph.
+
+ This method is mostly for documentation/presentation purposes.
+
+ Parameters
+ ----------
+ qgraph : `lsst.pipe.base.QuantumGraph`
+ QuantumGraph instance.
+ file : `str` or file object
+ File where GraphViz graph (DOT language) is written, can be a file name
+ or file object.
+
+ Raises
+ ------
+ `OSError` is raised when output file cannot be open.
+ `ImportError` is raised when task class cannot be imported.
+ """
+ # open a file if needed
+ close = False
+ if not hasattr(file, "write"):
+ file = open(file, "w")
+ close = True
+
+ print("digraph QuantumGraph {", file=file)
+ _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
+ _renderDefault("node", _ATTRIBS["defaultNode"], file)
+ _renderDefault("edge", _ATTRIBS["defaultEdge"], file)
+
+ allDatasetRefs: dict[str, str] = {}
+ for taskId, taskDef in enumerate(qgraph.taskGraph):
+ quanta = qgraph.getNodesForTask(taskDef)
+ for qId, quantumNode in enumerate(quanta):
+ # node for a task
+ taskNodeName = f"task_{taskId}_{qId}"
+ _renderQuantumNode(taskNodeName, taskDef, quantumNode, file)
+
+ # quantum inputs
+ for dsRefs in quantumNode.quantum.inputs.values():
+ for dsRef in dsRefs:
+ nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
+ _renderEdge(nodeName, taskNodeName, file)
+
+ # quantum outputs
+ for dsRefs in quantumNode.quantum.outputs.values():
+ for dsRef in dsRefs:
+ nodeName = _makeDSNode(dsRef, allDatasetRefs, file)
+ _renderEdge(taskNodeName, nodeName, file)
+
+ print("}", file=file)
+ if close:
+ file.close()
+
+
+def pipeline2dot(pipeline: Pipeline | Iterable[TaskDef], file: Any) -> None:
+ """Convert `~lsst.pipe.base.Pipeline` into GraphViz digraph.
+
+ This method is mostly for documentation/presentation purposes.
+ Unlike other methods this method does not validate graph consistency.
+
+ Parameters
+ ----------
+ pipeline : `lsst.pipe.base.Pipeline`
+ Pipeline description.
+ file : `str` or file object
+ File where GraphViz graph (DOT language) is written, can be a file name
+ or file object.
+
+ Raises
+ ------
+ `OSError` is raised when output file cannot be open.
+ `ImportError` is raised when task class cannot be imported.
+ `MissingTaskFactoryError` is raised when TaskFactory is needed but not
+ provided.
+ """
+ universe = DimensionUniverse()
+
+ def expand_dimensions(connection: connectionTypes.BaseConnection) -> list[str]:
+ """Return expanded list of dimensions, with special skypix treatment.
+
+ Parameters
+ ----------
+ connection : `list` [`str`]
+ Connection to examine.
+
+ Returns
+ -------
+ dimensions : `list` [`str`]
+ Expanded list of dimensions.
+ """
+ dimension_set = set()
+ if isinstance(connection, connectionTypes.DimensionedConnection):
+ dimension_set = set(connection.dimensions)
+ skypix_dim = []
+ if "skypix" in dimension_set:
+ dimension_set.remove("skypix")
+ skypix_dim = ["skypix"]
+ dimensions = universe.conform(dimension_set)
+ return list(dimensions.names) + skypix_dim
+
+ # open a file if needed
+ close = False
+ if not hasattr(file, "write"):
+ file = open(file, "w")
+ close = True
+
+ print("digraph Pipeline {", file=file)
+ _renderDefault("graph", _ATTRIBS["defaultGraph"], file)
+ _renderDefault("node", _ATTRIBS["defaultNode"], file)
+ _renderDefault("edge", _ATTRIBS["defaultEdge"], file)
+
+ allDatasets: set[str | tuple[str, str]] = set()
+ if isinstance(pipeline, Pipeline):
+ # TODO: DM-40639 will rewrite this code and finish off the deprecation
+ # of toExpandedPipeline but for now use the compatibility API.
+ pipeline = pipeline.to_graph()._iter_task_defs()
+
+ # The next two lines are a workaround until DM-29658 at which time metadata
+ # connections should start working with the above code
+ labelToTaskName = {}
+ metadataNodesToLink = set()
+
+ for idx, taskDef in enumerate(sorted(pipeline, key=lambda x: x.label)):
+ # node for a task
+ taskNodeName = f"task{idx}"
+
+ # next line is workaround until DM-29658
+ labelToTaskName[taskDef.label] = taskNodeName
+
+ _renderTaskNode(taskNodeName, taskDef, file, None)
+
+ metadataRePattern = re.compile("^(.*)_metadata$")
+ for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
+ if attr.name not in allDatasets:
+ dimensions = expand_dimensions(attr)
+ _renderDSTypeNode(attr.name, dimensions, file)
+ allDatasets.add(attr.name)
+ nodeName, component = DatasetType.splitDatasetTypeName(attr.name)
+ _renderEdge(attr.name, taskNodeName, file)
+ # connect component dataset types to the composite type that
+ # produced it
+ if component is not None and (nodeName, attr.name) not in allDatasets:
+ _renderEdge(nodeName, attr.name, file)
+ allDatasets.add((nodeName, attr.name))
+ if nodeName not in allDatasets:
+ dimensions = expand_dimensions(attr)
+ _renderDSTypeNode(nodeName, dimensions, file)
+ # The next if block is a workaround until DM-29658 at which time
+ # metadata connections should start working with the above code
+ if (match := metadataRePattern.match(attr.name)) is not None:
+ matchTaskLabel = match.group(1)
+ metadataNodesToLink.add((matchTaskLabel, attr.name))
+
+ for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name):
+ if attr.name not in allDatasets:
+ dimensions = expand_dimensions(attr)
+ _renderDSTypeNode(attr.name, dimensions, file)
+ allDatasets.add(attr.name)
+ # use dashed line for prerequisite edges to distinguish them
+ _renderEdge(attr.name, taskNodeName, file, style="dashed")
+
+ for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name):
+ if attr.name not in allDatasets:
+ dimensions = expand_dimensions(attr)
+ _renderDSTypeNode(attr.name, dimensions, file)
+ allDatasets.add(attr.name)
+ _renderEdge(taskNodeName, attr.name, file)
+
+ # This for loop is a workaround until DM-29658 at which time metadata
+ # connections should start working with the above code
+ for matchLabel, dsTypeName in metadataNodesToLink:
+ # only render an edge to metadata if the label is part of the current
+ # graph
+ if (result := labelToTaskName.get(matchLabel)) is not None:
+ _renderEdge(result, dsTypeName, file)
+
+ print("}", file=file)
+ if close:
+ file.close()
diff --git a/tests/test_dot_tools.py b/tests/test_dot_tools.py
new file mode 100644
index 000000000..e760d04ab
--- /dev/null
+++ b/tests/test_dot_tools.py
@@ -0,0 +1,199 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Simple unit test for Pipeline visualization.
+"""
+
+import io
+import re
+import unittest
+
+import lsst.pipe.base.connectionTypes as cT
+import lsst.utils.tests
+from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections
+from lsst.pipe.base.dot_tools import pipeline2dot
+
+
+class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=()):
+ """Connections class used for testing.
+
+ Parameters
+ ----------
+ config : `PipelineTaskConfig`
+ The config to use for this connections class.
+ """
+
+ input1 = cT.Input(
+ name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task"
+ )
+ input2 = cT.Input(
+ name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task"
+ )
+ output1 = cT.Output(
+ name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task"
+ )
+ output2 = cT.Output(
+ name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task"
+ )
+
+ def __init__(self, *, config=None):
+ super().__init__(config=config)
+ if not config.connections.input2:
+ self.inputs.remove("input2")
+ if not config.connections.output2:
+ self.outputs.remove("output2")
+
+
+class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections):
+ """Example config used for testing."""
+
+
+def _makeConfig(inputName, outputName, pipeline, label):
+ """Add config overrides.
+
+ Factory method for config instances.
+
+ inputName and outputName can be either string or tuple of strings
+ with two items max.
+ """
+ if isinstance(inputName, tuple):
+ pipeline.addConfigOverride(label, "connections.input1", inputName[0])
+ pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "")
+ else:
+ pipeline.addConfigOverride(label, "connections.input1", inputName)
+
+ if isinstance(outputName, tuple):
+ pipeline.addConfigOverride(label, "connections.output1", outputName[0])
+ pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "")
+ else:
+ pipeline.addConfigOverride(label, "connections.output1", outputName)
+
+
+class ExamplePipelineTask(PipelineTask):
+ """Example pipeline task used for testing."""
+
+ ConfigClass = ExamplePipelineTaskConfig
+
+
+def _makePipeline(tasks):
+ """Generate Pipeline instance.
+
+ Parameters
+ ----------
+ tasks : list of tuples
+ Each tuple in the list has 3 or 4 items:
+ - input DatasetType name(s), string or tuple of strings
+ - output DatasetType name(s), string or tuple of strings
+ - task label, string
+ - optional task class object, can be None
+
+ Returns
+ -------
+ Pipeline instance
+ """
+ pipe = Pipeline("test pipeline")
+ for task in tasks:
+ inputs = task[0]
+ outputs = task[1]
+ label = task[2]
+ klass = task[3] if len(task) > 3 else ExamplePipelineTask
+ pipe.addTask(klass, label)
+ _makeConfig(inputs, outputs, pipe, label)
+ return list(pipe.to_graph()._iter_task_defs())
+
+
+class DotToolsTestCase(unittest.TestCase):
+ """A test case for dotTools."""
+
+ def test_pipeline2dot(self):
+ """Tests for dot_tools.pipeline2dot method."""
+ pipeline = _makePipeline(
+ [
+ ("A", ("B", "C"), "task0"),
+ ("C", "E", "task1"),
+ ("B", "D", "task2"),
+ (("D", "E"), "F", "task3"),
+ ("D.C", "G", "task4"),
+ ("task3_metadata", "H", "task5"),
+ ]
+ )
+ file = io.StringIO()
+ pipeline2dot(pipeline, file)
+
+ # It's hard to validate complete output, just checking few basic
+ # things, even that is not terribly stable.
+ lines = file.getvalue().strip().split("\n")
+ nglobals = 3
+ ndatasets = 10
+ ntasks = 6
+ nedges = 16
+ nextra = 2 # graph header and closing
+ self.assertEqual(len(lines), nglobals + ndatasets + ntasks + nedges + nextra)
+
+ # make sure that all node names are quoted
+ nodeRe = re.compile(r"^([^ ]+) \[.+\];$")
+ edgeRe = re.compile(r"^([^ ]+) *-> *([^ ]+);$")
+ for line in lines:
+ match = nodeRe.match(line)
+ if match:
+ node = match.group(1)
+ if node not in ["graph", "node", "edge"]:
+ self.assertEqual(node[0] + node[-1], '""')
+ continue
+ match = edgeRe.match(line)
+ if match:
+ for group in (1, 2):
+ node = match.group(group)
+ self.assertEqual(node[0] + node[-1], '""')
+ continue
+
+ # make sure components are connected appropriately
+ self.assertIn('"D" -> "D.C"', file.getvalue())
+
+ # make sure there is a connection created for metadata if someone
+ # tries to read it in
+ self.assertIn('"task3" -> "task3_metadata"', file.getvalue())
+
+
+class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ """Generic file handle leak check."""
+
+
+def setup_module(module):
+ """Set up the module for pytest.
+
+ Parameters
+ ----------
+ module : `~types.ModuleType`
+ Module to set up.
+ """
+ lsst.utils.tests.init()
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()
|