diff --git a/python/lsst/daf/butler/_dataset_association.py b/python/lsst/daf/butler/_dataset_association.py index a836a50682..6572fe0c38 100644 --- a/python/lsst/daf/butler/_dataset_association.py +++ b/python/lsst/daf/butler/_dataset_association.py @@ -29,12 +29,17 @@ __all__ = ("DatasetAssociation",) +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any from ._dataset_ref import DatasetRef +from ._dataset_type import DatasetType from ._timespan import Timespan +if TYPE_CHECKING: + from .queries._general_query_results import GeneralQueryResults + @dataclass(frozen=True, eq=True) class DatasetAssociation: @@ -59,6 +64,26 @@ class DatasetAssociation: collection (`Timespan` or `None`). """ + @classmethod + def from_query_result( + cls, result: GeneralQueryResults, dataset_type: DatasetType + ) -> Iterator[DatasetAssociation]: + """Construct dataset associations from the result of general query. + + Parameters + ---------- + result : `GeneralQueryResults` + General query result returned by `Query.general` method. The result + has to include "{dataset_type.name}.timespan" and + "{dataset_type.name}.collection" columns. + dataset_type : `DatasetType` + Dataset type, query has to include this dataset type. + """ + timespan_key = f"{dataset_type.name}.timespan" + collection_key = f"{dataset_type.name}.collection" + for _, refs, row_dict in result.iter_tuples(dataset_type): + yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + def __lt__(self, other: Any) -> bool: # Allow sorting of associations if not isinstance(other, type(self)): diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 990616743f..fb6f51503c 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -2432,10 +2432,7 @@ def queryDatasetAssociations( datasetType.dimensions, dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, ) - timespan_key = f"{datasetType.name}.timespan" - collection_key = f"{datasetType.name}.collection" - for _, refs, row_dict in result.iter_tuples(datasetType): - yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + yield from DatasetAssociation.from_query_result(result, datasetType) def get_datastore_records(self, ref: DatasetRef) -> DatasetRef: """Retrieve datastore records for given ref. diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index a74e6da454..bc7515a4c4 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -525,10 +525,7 @@ def queryDatasetAssociations( datasetType.dimensions, dataset_fields={datasetType.name: {"dataset_id", "run", "collection", "timespan"}}, ) - timespan_key = f"{datasetType.name}.timespan" - collection_key = f"{datasetType.name}.collection" - for _, refs, row_dict in result.iter_tuples(datasetType): - yield DatasetAssociation(refs[0], row_dict[collection_key], row_dict[timespan_key]) + yield from DatasetAssociation.from_query_result(result, datasetType) @property def storageClasses(self) -> StorageClassFactory: