From c904bdfda3d3a2529b1ff9abc97b4fcdf80c7efb Mon Sep 17 00:00:00 2001
From: Jim Bosch
Date: Fri, 30 Aug 2024 15:52:58 -0400
Subject: [PATCH] Support storage class conversions of components in
 PipelineGraph.

This can leave the component DatasetType/Ref seen by the
component-reading task with an unexpected parentStorageClass, but we
hope the task won't care.
---
 doc/changes/DM-46064.feature.md              |  1 +
 .../lsst/pipe/base/pipeline_graph/_edges.py  | 28 +++++------
 tests/test_pipeline_graph.py                 | 46 +++++++++++++++++--
 3 files changed, 57 insertions(+), 18 deletions(-)
 create mode 100644 doc/changes/DM-46064.feature.md

diff --git a/doc/changes/DM-46064.feature.md b/doc/changes/DM-46064.feature.md
new file mode 100644
index 000000000..4babd4d09
--- /dev/null
+++ b/doc/changes/DM-46064.feature.md
@@ -0,0 +1 @@
+Storage class conversions of component dataset types are now supported in pipelines.
diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py
index 4e6e6807d..e8b16428e 100644
--- a/python/lsst/pipe/base/pipeline_graph/_edges.py
+++ b/python/lsst/pipe/base/pipeline_graph/_edges.py
@@ -32,7 +32,7 @@
 from collections.abc import Callable, Mapping, Sequence
 from typing import Any, ClassVar, Self, TypeVar
 
-from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse
+from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
 from lsst.daf.butler.registry import MissingDatasetTypeError
 from lsst.utils.classes import immutable
 
@@ -396,10 +396,7 @@ def diff(self: ReadEdge, other: ReadEdge, connection_type: str = "connection") -
     def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType:
         # Docstring inherited.
         if self.component is not None:
-            assert (
-                self.storage_class_name == dataset_type.storageClass.allComponents()[self.component].name
-            ), "components with storage class overrides are not supported"
-            return dataset_type.makeComponentDatasetType(self.component)
+            dataset_type = dataset_type.makeComponentDatasetType(self.component)
         if self.storage_class_name != dataset_type.storageClass_name:
             return dataset_type.overrideStorageClass(self.storage_class_name)
         return dataset_type
@@ -407,10 +404,7 @@ def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType:
     def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
         # Docstring inherited.
         if self.component is not None:
-            assert (
-                self.storage_class_name == ref.datasetType.storageClass.allComponents()[self.component].name
-            ), "components with storage class overrides are not supported"
-            return ref.makeComponentRef(self.component)
+            ref = ref.makeComponentRef(self.component)
         if self.storage_class_name != ref.datasetType.storageClass_name:
             return ref.overrideStorageClass(self.storage_class_name)
         return ref
@@ -618,13 +612,21 @@ def report_current_origin() -> str:
                     f"which does not include component {self.component!r} "
                     f"as requested by task {self.task_label!r}."
                 )
-            if all_current_components[self.component].name != self.storage_class_name:
+            # Note that we can't actually make a fully-correct DatasetType
+            # for the component the task wants, because we don't have the
+            # parent storage class.
+            current_component = all_current_components[self.component]
+            if (
+                current_component.name != self.storage_class_name
+                and not StorageClassFactory()
+                .getStorageClass(self.storage_class_name)
+                .can_convert(current_component)
+            ):
                 raise IncompatibleDatasetTypeError(
                     f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class "
                     f"{all_current_components[self.component].name!r} "
-                    f"(from {report_current_origin()}), which does not match "
-                    f"{self.storage_class_name!r}, as requested by task {self.task_label!r}. "
-                    "Note that storage class conversions of components are not supported."
+                    f"(from {report_current_origin()}), which cannot be converted to "
+                    f"{self.storage_class_name!r}, as requested by task {self.task_label!r}."
                 )
             return current, is_initial_query_constraint, is_prerequisite
         else:
diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py
index 0399ac7df..d4935ccd9 100644
--- a/tests/test_pipeline_graph.py
+++ b/tests/test_pipeline_graph.py
@@ -1022,11 +1022,11 @@ def _have_example_storage_classes() -> bool:
     """Check whether some storage classes work as expected.
 
     Given that these have registered converters, it shouldn't actually be
-    necessary to import be able to those types in order to determine that
-    they're convertible, but the storage class machinery is implemented such
-    that types that can't be imported can't be converted, and while that's
-    inconvenient here it's totally fine in non-testing scenarios where you only
-    care about a storage class if you can actually use it.
+    necessary to import those types in order to determine that they're
+    convertible, but the storage class machinery is implemented such that types
+    that can't be imported can't be converted, and while that's inconvenient
+    here it's totally fine in non-testing scenarios where you only care about a
+    storage class if you can actually use it.
     """
     getter = StorageClassFactory().getStorageClass
     return (
@@ -1416,6 +1416,42 @@ def test_component_resolved_by_output(self) -> None:
         self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
         self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
 
+    @unittest.skipUnless(
+        _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+    )
+    def test_component_storage_class_converted(self) -> None:
+        """Test successful resolution of a component dataset type due to
+        an output connection referencing the parent dataset type, but with a
+        different (convertible) storage class.
+        """
+        self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="DataFrame")
+        self.b_config.inputs["i"] = DynamicConnectionConfig(
+            dataset_type_name="d.schema", storage_class="ArrowSchema"
+        )
+        graph = self.make_graph()
+        output_parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
+        graph.resolve(MockRegistry(self.dimensions, {}))
+        self.assertEqual(graph.dataset_types["d"].dataset_type, output_parent_dataset_type)
+        a_o = graph.tasks["a"].outputs["o"]
+        b_i = graph.tasks["b"].inputs["i"]
+        self.assertEqual(b_i.dataset_type_name, "d.schema")
+        self.assertEqual(a_o.adapt_dataset_type(output_parent_dataset_type), output_parent_dataset_type)
+        self.assertEqual(
+            # We don't really want to compare the full dataset type here,
+            # because that's going to include a parentStorageClass that may or
+            # may not make sense.
+            b_i.adapt_dataset_type(output_parent_dataset_type).storageClass_name,
+            get_mock_name("ArrowSchema"),
+        )
+        data_id = DataCoordinate.make_empty(self.dimensions)
+        ref = DatasetRef(output_parent_dataset_type, data_id, run="r")
+        a_ref = a_o.adapt_dataset_ref(ref)
+        b_ref = b_i.adapt_dataset_ref(ref)
+        self.assertEqual(a_ref, ref)
+        self.assertEqual(b_ref.datasetType.storageClass_name, get_mock_name("ArrowSchema"))
+        self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+        self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
     @unittest.skipUnless(
         _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
    )
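
A rough sketch of the conversion this patch enables, for readers who want to
try it outside a full pipeline. The DataFrame/ArrowSchema storage class names
and the "schema" component are illustrative assumptions (they presume those
example storage classes, their converters, and that component are registered,
as the new test requires via its mock equivalents); none of this code is part
of the patch:

    from lsst.daf.butler import DatasetType, DimensionUniverse, StorageClassFactory

    universe = DimensionUniverse()
    # Parent dataset type, as registered by the task that writes "d".
    parent = DatasetType("d", universe.empty, "DataFrame")
    # Make the component dataset type first, then override its storage
    # class, mirroring the reworked ReadEdge.adapt_dataset_type.
    component = parent.makeComponentDatasetType("schema")
    requested = StorageClassFactory().getStorageClass("ArrowSchema")
    # Resolution now accepts the override whenever the requested storage
    # class can convert the component's own, instead of requiring an
    # exact match.
    if component.storageClass_name != requested.name and requested.can_convert(
        component.storageClass
    ):
        component = component.overrideStorageClass(requested.name)

As the commit message notes, the resulting component DatasetType may still
carry an unexpected parentStorageClass; the conversion applies only to the
component's own storage class.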