From 90aadc18b9f6bcac43c740d0c5a72f35f0b1a067 Mon Sep 17 00:00:00 2001
From: Andy Salnikov
Date: Mon, 5 Aug 2024 15:38:15 -0700
Subject: [PATCH] Initial implementation of a general query result (DM-45429)

Registry method `queryDatasetAssociations` is reimplemented (for both
direct and remote butler) to use the new query system and the new general
query result class.
---
 .../butler/direct_butler/_direct_butler.py    |  21 +--
 .../daf/butler/direct_query_driver/_driver.py |   3 +
 .../_result_page_converter.py                 |  89 ++++++++++++-
 python/lsst/daf/butler/queries/__init__.py    |   1 +
 .../butler/queries/_general_query_results.py  | 125 ++++++++++++++++++
 python/lsst/daf/butler/queries/_query.py      |  40 +++++-
 python/lsst/daf/butler/queries/driver.py      |   3 +-
 .../lsst/daf/butler/registry/sql_registry.py  |  95 ++++++-------
 .../daf/butler/remote_butler/_registry.py     |  16 ++-
 9 files changed, 317 insertions(+), 76 deletions(-)
 create mode 100644 python/lsst/daf/butler/queries/_general_query_results.py

diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py
index 68ddec7932..ac2bfb3df5 100644
--- a/python/lsst/daf/butler/direct_butler/_direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -2195,32 +2195,19 @@ def dimensions(self) -> DimensionUniverse:
         # Docstring inherited.
         return self._registry.dimensions

-    @contextlib.contextmanager
-    def _query(self) -> Iterator[Query]:
+    def _query(self) -> contextlib.AbstractContextManager[Query]:
         # Docstring inherited.
-        with self._query_driver(self._registry.defaults.collections, self.registry.defaults.dataId) as driver:
-            yield Query(driver)
+        return self._registry._query()

-    @contextlib.contextmanager
     def _query_driver(
         self,
         default_collections: Iterable[str],
         default_data_id: DataCoordinate,
-    ) -> Iterator[DirectQueryDriver]:
+    ) -> contextlib.AbstractContextManager[DirectQueryDriver]:
         """Set up a QueryDriver instance for use with this Butler.

         Although this is marked as a private method, it is also used by
         Butler server.
""" - with self._caching_context(): - driver = DirectQueryDriver( - self._registry._db, - self.dimensions, - self._registry._managers, - self._registry.dimension_record_cache, - default_collections=default_collections, - default_data_id=default_data_id, - ) - with driver: - yield driver + return self._registry._query_driver(default_collections, default_data_id) def _preload_cache(self) -> None: """Immediately load caches that are used for common operations.""" diff --git a/python/lsst/daf/butler/direct_query_driver/_driver.py b/python/lsst/daf/butler/direct_query_driver/_driver.py index 89ce006057..0aaa848b2c 100644 --- a/python/lsst/daf/butler/direct_query_driver/_driver.py +++ b/python/lsst/daf/butler/direct_query_driver/_driver.py @@ -79,6 +79,7 @@ DataCoordinateResultPageConverter, DatasetRefResultPageConverter, DimensionRecordResultPageConverter, + GeneralResultPageConverter, ResultPageConverter, ResultPageConverterContext, ) @@ -271,6 +272,8 @@ def _create_result_page_converter(self, spec: ResultSpec, builder: QueryBuilder) return DatasetRefResultPageConverter( spec, self.get_dataset_type(spec.dataset_type_name), context ) + case GeneralResultSpec(): + return GeneralResultPageConverter(spec, context) case _: raise NotImplementedError(f"Result type '{spec.result_type}' not yet implemented") diff --git a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py index 22044994f1..f1419bcff4 100644 --- a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py +++ b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py @@ -30,7 +30,7 @@ from abc import abstractmethod from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import sqlalchemy @@ -50,9 +50,16 @@ DataCoordinateResultPage, DatasetRefResultPage, DimensionRecordResultPage, + GeneralResultPage, ResultPage, ) -from ..queries.result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from ..queries.result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, +) +from ..timespan_database_representation import TimespanDatabaseRepresentation if TYPE_CHECKING: from ..registry.interfaces import Database @@ -310,3 +317,81 @@ def convert(self, row: sqlalchemy.Row) -> dict[str, DimensionRecord]: # numpydo the dimensions in the database row. 
""" return {name: converter.convert(row) for name, converter in self._record_converters.items()} + + +class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01 + """Converts raw SQL rows into pages of `GeneralResult` query results.""" + + def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None: + self.spec = spec + + result_columns = spec.get_result_columns() + self.converters: list[_GeneralColumnConverter] = [] + for column in result_columns: + column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field) + if column.field == TimespanDatabaseRepresentation.NAME: + self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db)) + else: + self.converters.append(_DefaultGeneralColumnConverter(column_name)) + + def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage: + rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows] + return GeneralResultPage(spec=self.spec, rows=rows) + + +class _GeneralColumnConverter: + """Interface for converting one or more columns in a result row to a single + column value in output row. + """ + + @abstractmethod + def convert(self, row: sqlalchemy.Row) -> Any: + """Convert one or more columns in the row into single value. + + Parameters + ---------- + row : `sqlalchemy.Row` + Row of values. + + Returns + ------- + value : `Any` + Result of the conversion. + """ + raise NotImplementedError() + + +class _DefaultGeneralColumnConverter(_GeneralColumnConverter): + """Converter that returns column value without conversion. + + Parameters + ---------- + name : `str` + Column name + """ + + def __init__(self, name: str): + self.name = name + + def convert(self, row: sqlalchemy.Row) -> Any: + return row._mapping[self.name] + + +class _TimespanGeneralColumnConverter(_GeneralColumnConverter): + """Converter that extracts timespan from the row. + + Parameters + ---------- + name : `str` + Column name or prefix. + db : `Database` + Database instance. + """ + + def __init__(self, name: str, db: Database): + self.timespan_class = db.getTimespanRepresentation() + self.name = name + + def convert(self, row: sqlalchemy.Row) -> Any: + timespan = self.timespan_class.extract(row._mapping, self.name) + return timespan diff --git a/python/lsst/daf/butler/queries/__init__.py b/python/lsst/daf/butler/queries/__init__.py index 15743f291f..720e4ca6d1 100644 --- a/python/lsst/daf/butler/queries/__init__.py +++ b/python/lsst/daf/butler/queries/__init__.py @@ -29,4 +29,5 @@ from ._data_coordinate_query_results import * from ._dataset_query_results import * from ._dimension_record_query_results import * +from ._general_query_results import * from ._query import * diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py new file mode 100644 index 0000000000..ae99e6ecdf --- /dev/null +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -0,0 +1,125 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("GeneralQueryResults",) + +from collections.abc import Iterator +from typing import Any, final + +from .._dataset_ref import DatasetRef +from .._dataset_type import DatasetType +from ..dimensions import DataCoordinate, DimensionGroup +from ._base import QueryResultsBase +from .driver import QueryDriver +from .result_specs import GeneralResultSpec +from .tree import QueryTree, ResultColumn + + +@final +class GeneralQueryResults(QueryResultsBase): + """A query for `DatasetRef` results with a single dataset type. + + Parameters + ---------- + driver : `QueryDriver` + Implementation object that knows how to actually execute queries. + tree : `QueryTree` + Description of the query as a tree of joins and column expressions. The + instance returned directly by the `Butler._query` entry point should be + constructed via `make_unit_query_tree`. + spec : `GeneralResultSpec` + Specification of the query result rows, including output columns, + ordering, and slicing. + + Notes + ----- + This class should never be constructed directly by users; use `Query` + methods instead. + """ + + def __init__(self, driver: QueryDriver, tree: QueryTree, spec: GeneralResultSpec): + spec.validate_tree(tree) + super().__init__(driver, tree) + self._spec = spec + + def __iter__(self) -> Iterator[dict[ResultColumn, Any]]: + """Iterate over result rows. + + Yields + ------ + row_dict : `dict` [`ResultColumn`, `Any`] + Result row as dictionary, the keys are `ResultColumn` instances. + """ + for page in self._driver.execute(self._spec, self._tree): + columns = tuple(page.spec.get_result_columns()) + for row in page.rows: + yield dict(zip(columns, row)) + + def iter_refs(self, dataset_type: DatasetType) -> Iterator[tuple[DatasetRef, dict[ResultColumn, Any]]]: + """Iterate over result rows and return DatasetRef constructed from each + row and an original row. + + Parameters + ---------- + dataset_type : `DatasetType` + Type of the dataset to return. + + Yields + ------ + dataset_ref : `DatasetRef` + Dataset reference. + row_dict : `dict` [`ResultColumn`, `Any`] + Result row as dictionary, the keys are `ResultColumn` instances. 
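+
+        Notes
+        -----
+        A minimal usage sketch, assuming ``results`` is an instance of this
+        class obtained from `Query.dataset_associations` and ``butler`` is a
+        `Butler`; the ``"bias"`` dataset type name is illustrative::
+
+            bias_type = butler.get_dataset_type("bias")
+            for ref, row in results.iter_refs(bias_type):
+                # ref is a complete DatasetRef; row still carries the extra
+                # columns, keyed by ResultColumn.
+                print(ref.id, ref.run)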
+ """ + dimensions = dataset_type.dimensions + id_key = ResultColumn(logical_table=dataset_type.name, field="dataset_id") + run_key = ResultColumn(logical_table=dataset_type.name, field="run") + data_id_keys = [ResultColumn(logical_table=element, field=None) for element in dimensions.required] + for row in self: + values = tuple(row[key] for key in data_id_keys) + data_id = DataCoordinate.from_required_values(dimensions, values) + ref = DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]) + yield ref, row + + @property + def dimensions(self) -> DimensionGroup: + # Docstring inherited + return self._spec.dimensions + + def count(self, *, exact: bool = True, discard: bool = False) -> int: + # Docstring inherited. + return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) + + def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults: + # Docstring inherited. + return GeneralQueryResults(self._driver, tree, self._spec.model_copy(update=kwargs)) + + def _get_datasets(self) -> frozenset[str]: + # Docstring inherited. + return frozenset(self._spec.dataset_fields) diff --git a/python/lsst/daf/butler/queries/_query.py b/python/lsst/daf/butler/queries/_query.py index dedb8ee36e..f9fe08470c 100644 --- a/python/lsst/daf/butler/queries/_query.py +++ b/python/lsst/daf/butler/queries/_query.py @@ -43,10 +43,16 @@ from ._data_coordinate_query_results import DataCoordinateQueryResults from ._dataset_query_results import DatasetRefQueryResults from ._dimension_record_query_results import DimensionRecordQueryResults +from ._general_query_results import GeneralQueryResults from .convert_args import convert_where_args from .driver import QueryDriver from .expression_factory import ExpressionFactory -from .result_specs import DataCoordinateResultSpec, DatasetRefResultSpec, DimensionRecordResultSpec +from .result_specs import ( + DataCoordinateResultSpec, + DatasetRefResultSpec, + DimensionRecordResultSpec, + GeneralResultSpec, +) from .tree import DatasetSearch, Predicate, QueryTree, make_identity_query_tree @@ -287,6 +293,38 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults: result_spec = DimensionRecordResultSpec(element=self._driver.universe[element]) return DimensionRecordQueryResults(self._driver, tree, result_spec) + def dataset_associations( + self, + dataset_type: DatasetType, + collections: Iterable[str], + ) -> GeneralQueryResults: + """Iterate over dataset-collection combinations where the dataset is in + the collection. + + Parameters + ---------- + dataset_type : `DatasetType` + A dataset type object. + collections : `~collections.abc.Iterable` [`str`] + Names of the collections to search. Chained collections are + ignored. + + Returns + ------- + result : `GeneralQueryResults` + Query result that can be iterated over. The result includes all + columns needed to construct `DatasetRef`, plus ``collection`` and + ``timespan`` columns. 
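+
+        Notes
+        -----
+        A minimal usage sketch; the dataset type object ``bias_type`` and
+        the collection name are illustrative::
+
+            with butler._query() as query:
+                results = query.dataset_associations(bias_type, ["LATISS/calib"])
+                for row in results:
+                    # Each row is a dict keyed by ResultColumn.
+                    print(row)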
+ """ + _, _, query = self._join_dataset_search_impl(dataset_type, collections) + result_spec = GeneralResultSpec( + dimensions=dataset_type.dimensions, + dimension_fields={}, + dataset_fields={dataset_type.name: {"dataset_id", "run", "collection", "timespan"}}, + find_first=False, + ) + return GeneralQueryResults(self._driver, tree=query._tree, spec=result_spec) + def materialize( self, *, diff --git a/python/lsst/daf/butler/queries/driver.py b/python/lsst/daf/butler/queries/driver.py index 6df857b218..f7a6af352a 100644 --- a/python/lsst/daf/butler/queries/driver.py +++ b/python/lsst/daf/butler/queries/driver.py @@ -116,7 +116,8 @@ class GeneralResultPage: spec: GeneralResultSpec - # Raw tabular data, with columns in the same order as spec.columns. + # Raw tabular data, with columns in the same order as + # spec.get_result_columns(). rows: list[tuple[Any, ...]] diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 6e427e6fb2..f38f094c50 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -66,7 +66,10 @@ DimensionUniverse, ) from ..dimensions.record_cache import DimensionRecordCache +from ..direct_query_driver import DirectQueryDriver from ..progress import Progress +from ..queries import Query +from ..queries.tree import ResultColumn from ..registry import ( ArgumentError, CollectionExpressionError, @@ -2412,6 +2415,33 @@ def queryDimensionRecords( query = builder.finish().with_record_columns(element.name) return queries.DatabaseDimensionRecordQueryResults(query, element) + @contextlib.contextmanager + def _query(self) -> Iterator[Query]: + """Context manager returning a `Query` object used for construction + and execution of complex queries. + """ + with self._query_driver(self.defaults.collections, self.defaults.dataId) as driver: + yield Query(driver) + + @contextlib.contextmanager + def _query_driver( + self, + default_collections: Iterable[str], + default_data_id: DataCoordinate, + ) -> Iterator[DirectQueryDriver]: + """Set up a `QueryDriver` instance for query execution.""" + with self.caching_context(): + driver = DirectQueryDriver( + self._db, + self.dimensions, + self._managers, + self.dimension_record_cache, + default_collections=default_collections, + default_data_id=default_data_id, + ) + with driver: + yield driver + def queryDatasetAssociations( self, datasetType: str | DatasetType, @@ -2462,59 +2492,18 @@ def queryDatasetAssociations( lsst.daf.butler.registry.CollectionExpressionError Raised when ``collections`` expression is invalid. """ - if collections is None: - if not self.defaults.collections: - raise NoDefaultCollectionError( - "No collections provided to queryDatasetAssociations, " - "and no defaults from registry construction." 
- ) - collections = self.defaults.collections - collection_wildcard = CollectionWildcard.from_expression(collections) - backend = queries.SqlQueryBackend(self._db, self._managers, self.dimension_record_cache) - parent_dataset_type = backend.resolve_single_dataset_type_wildcard(datasetType) - timespan_tag = DatasetColumnTag(parent_dataset_type.name, "timespan") - collection_tag = DatasetColumnTag(parent_dataset_type.name, "collection") - for parent_collection_record in backend.resolve_collection_wildcard( - collection_wildcard, - collection_types=frozenset(collectionTypes), - flatten_chains=flattenChains, - ): - # Resolve this possibly-chained collection into a list of - # non-CHAINED collections that actually hold datasets of this - # type. - candidate_collection_records = backend.resolve_dataset_collections( - parent_dataset_type, - CollectionWildcard.from_names([parent_collection_record.name]), - allow_calibration_collections=True, - governor_constraints={}, - ) - if not candidate_collection_records: - continue - with backend.context() as context: - relation = backend.make_dataset_query_relation( - parent_dataset_type, - candidate_collection_records, - columns={"dataset_id", "run", "timespan", "collection"}, - context=context, - ) - reader = queries.DatasetRefReader( - parent_dataset_type, - translate_collection=lambda k: self._managers.collections[k].name, - full=False, - ) - for row in context.fetch_iterable(relation): - ref = reader.read(row) - collection_record = self._managers.collections[row[collection_tag]] - if collection_record.type is CollectionType.CALIBRATION: - timespan = row[timespan_tag] - else: - # For backwards compatibility and (possibly?) user - # convenience we continue to define the timespan of a - # DatasetAssociation row for a non-CALIBRATION - # collection to be None rather than a fully unbounded - # timespan. - timespan = None - yield DatasetAssociation(ref=ref, collection=collection_record.name, timespan=timespan) + if isinstance(datasetType, str): + datasetType = self.getDatasetType(datasetType) + resolved_collections = self.queryCollections( + collections, datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains + ) + with self._query() as query: + result = query.dataset_associations(datasetType, resolved_collections) + timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") + collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + for ref, row_dict in result.iter_refs(datasetType): + _LOG.debug("row_dict: %s", row_dict) + yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) def get_datastore_records(self, ref: DatasetRef) -> DatasetRef: """Retrieve datastore records for given ref. 
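
The reimplementation above keeps the public `queryDatasetAssociations`
interface unchanged; the following sketch shows typical use (the repository
path, dataset type, and collection name are illustrative, not part of this
patch):

    from lsst.daf.butler import Butler

    butler = Butler.from_config("/repo/main")  # hypothetical repository
    # Each association pairs a DatasetRef with the collection containing it;
    # for CALIBRATION collections the validity timespan is also populated.
    for assoc in butler.registry.queryDatasetAssociations(
        "bias", collections=["LATISS/calib"], flattenChains=True
    ):
        print(assoc.ref, assoc.collection, assoc.timespan)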
diff --git a/python/lsst/daf/butler/remote_butler/_registry.py b/python/lsst/daf/butler/remote_butler/_registry.py index f7dee118bd..c5fd5a5721 100644 --- a/python/lsst/daf/butler/remote_butler/_registry.py +++ b/python/lsst/daf/butler/remote_butler/_registry.py @@ -46,6 +46,7 @@ DimensionRecord, DimensionUniverse, ) +from ..queries.tree import ResultColumn from ..registry import ( CollectionArgType, CollectionSummary, @@ -65,12 +66,12 @@ DimensionRecordQueryResults, ) from ..registry.wildcards import CollectionWildcard, DatasetTypeWildcard -from ..remote_butler import RemoteButler from ._collection_args import ( convert_collection_arg_to_glob_string_list, convert_dataset_type_arg_to_glob_string_list, ) from ._http_connection import RemoteButlerHttpConnection, parse_model +from ._remote_butler import RemoteButler from .registry._query_common import CommonQueryArguments from .registry._query_data_coordinates import QueryDriverDataCoordinateQueryResults from .registry._query_datasets import QueryDriverDatasetRefQueryResults @@ -513,7 +514,18 @@ def queryDatasetAssociations( collectionTypes: Iterable[CollectionType] = CollectionType.all(), flattenChains: bool = False, ) -> Iterator[DatasetAssociation]: - raise NotImplementedError() + # queryCollections only accepts DatasetType. + if isinstance(datasetType, str): + datasetType = self.getDatasetType(datasetType) + resolved_collections = self.queryCollections( + collections, datasetType=datasetType, collectionTypes=collectionTypes, flattenChains=flattenChains + ) + with self._butler._query() as query: + result = query.dataset_associations(datasetType, resolved_collections) + timespan_key = ResultColumn(logical_table=datasetType.name, field="timespan") + collection_key = ResultColumn(logical_table=datasetType.name, field="collection") + for ref, row_dict in result.iter_refs(datasetType): + yield DatasetAssociation(ref, row_dict[collection_key], row_dict[timespan_key]) @property def storageClasses(self) -> StorageClassFactory:
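
Appendix (illustrative sketch, not part of the patch): consuming
`GeneralQueryResults` rows directly mirrors what both
`queryDatasetAssociations` implementations above do; ``butler`` and the
"bias" dataset type are assumed to exist:

    from lsst.daf.butler.queries.tree import ResultColumn

    bias_type = butler.get_dataset_type("bias")
    # Rows are dictionaries keyed by ResultColumn; build keys for the extra
    # columns that dataset_associations() requests.
    collection_key = ResultColumn(logical_table=bias_type.name, field="collection")
    timespan_key = ResultColumn(logical_table=bias_type.name, field="timespan")
    with butler._query() as query:
        results = query.dataset_associations(bias_type, ["LATISS/calib"])
        for ref, row in results.iter_refs(bias_type):
            print(ref.id, row[collection_key], row[timespan_key])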