Skip to content

Commit

Permalink
Merge pull request #1085 from lsst/tickets/DM-46347
Browse files Browse the repository at this point in the history
DM-46347: Fix issue where default data IDs were not being applied
  • Loading branch information
dhirving authored Oct 1, 2024
2 parents e8039eb + c3b4136 commit 907638f
Show file tree
Hide file tree
Showing 5 changed files with 310 additions and 123 deletions.
1 change: 1 addition & 0 deletions doc/changes/DM-46347.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue where default data IDs were not constraining query results in the new query system.
35 changes: 18 additions & 17 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from ..registry.managers import RegistryManagerInstances
from ..registry.wildcards import CollectionWildcard
from ._postprocessing import Postprocessing
from ._predicate_constraints_summary import PredicateConstraintsSummary
from ._query_builder import QueryBuilder, QueryJoiner
from ._query_plan import (
QueryFindFirstPlan,
Expand Down Expand Up @@ -927,24 +928,24 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query
tree.get_joined_dimension_groups(),
calibration_dataset_types,
)
result = QueryJoinsPlan(predicate=predicate, columns=builder.columns)

# Extract the data ID implied by the predicate; we can use the governor
# dimensions in that to constrain the collections we search for
# datasets later.
predicate_constraints = PredicateConstraintsSummary(predicate)
# Use the default data ID to apply additional constraints where needed.
predicate_constraints.apply_default_data_id(self._default_data_id, tree.dimensions)
predicate = predicate_constraints.predicate

result = QueryJoinsPlan(
predicate=predicate,
columns=builder.columns,
messages=predicate_constraints.messages,
)

# Add columns required by postprocessing.
builder.postprocessing.gather_columns_required(result.columns)
# We also check that the predicate doesn't reference any dimensions
# without constraining their governor dimensions, since that's a
# particularly easy mistake to make and it's almost never intentional.
# We also allow the registry data ID values to provide governor values.
where_governors: set[str] = set()
result.predicate.gather_governors(where_governors)
for governor in where_governors:
if governor not in result.constraint_data_id and governor not in result.governors_referenced:
if governor in self._default_data_id.dimensions:
result.constraint_data_id[governor] = self._default_data_id[governor]
else:
raise InvalidQueryError(
f"Query 'where' expression references a dimension dependent on {governor} without "
"constraining it directly."
)

# Add materializations, which can also bring in more postprocessing.
for m_key, m_dimensions in tree.materializations.items():
m_state = self._materializations[m_key]
Expand All @@ -969,7 +970,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> tuple[QueryJoinsPlan, Query
resolved_dataset_search = self._resolve_dataset_search(
dataset_type_name,
dataset_search,
result.constraint_data_id,
predicate_constraints.constraint_data_id,
summaries_by_dataset_type[dataset_type_name],
)
result.datasets[dataset_type_name] = resolved_dataset_search
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# This file is part of daf_butler.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import Any

from .._exceptions import InvalidQueryError
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionUniverse
from ..queries import tree as qt
from ..queries.visitors import ColumnExpressionVisitor, PredicateVisitFlags, SimplePredicateVisitor


class PredicateConstraintsSummary:
"""Summarizes information about the constraints on data ID values implied
by a Predicate.
Parameters
----------
predicate : `Predicate`
Predicate to summarize.
"""

predicate: qt.Predicate
"""The predicate examined by this summary."""

constraint_data_id: dict[str, DataIdValue]
"""Data ID values that will be identical in all result rows due to query
constraints.
"""

messages: list[str]
"""Diagnostic messages that report reasons the query may not return any
rows.
"""

def __init__(self, predicate: qt.Predicate) -> None:
self.predicate = predicate
self.constraint_data_id = {}
self.messages = []
# Governor dimensions referenced directly in the predicate, but not
# necessarily constrained to the same value in all logic branches.
self._governors_referenced: set[str] = set()

self.predicate.visit(
_DataIdExtractionVisitor(self.constraint_data_id, self.messages, self._governors_referenced)
)

def apply_default_data_id(
self, default_data_id: DataCoordinate, query_dimensions: DimensionGroup
) -> None:
"""Augment the predicate and summary by adding missing constraints for
governor dimensions using a default data ID.
Parameters
----------
default_data_id : `DataCoordinate`
Data ID values that will be used to constrain the query if governor
dimensions have not already been constrained by the predicate.
query_dimensions : `DimensionGroup`
The set of dimensions returned in result rows from the query.
"""
# Find governor dimensions required by the predicate.
# If these are not constrained by the predicate or the default data ID,
# we will raise an exception.
where_governors: set[str] = set()
self.predicate.gather_governors(where_governors)

# Add in governor dimensions that are returned in result rows.
# We constrain these using a default data ID if one is available,
# but it's not an error to omit the constraint.
governors_used_by_query = where_governors | query_dimensions.governors

# For each governor dimension needed by the query, add a constraint
# from the default data ID if the existing predicate does not
# constrain it.
for governor in governors_used_by_query:
if governor not in self.constraint_data_id and governor not in self._governors_referenced:
if governor in default_data_id.dimensions:
data_id_value = default_data_id[governor]
self.constraint_data_id[governor] = data_id_value
self._governors_referenced.add(governor)
self.predicate = self.predicate.logical_and(
_create_data_id_predicate(governor, data_id_value, query_dimensions.universe)
)
elif governor in where_governors:
# Check that the predicate doesn't reference any dimensions
# without constraining their governor dimensions, since
# that's a particularly easy mistake to make and it's
# almost never intentional.
raise InvalidQueryError(
f"Query 'where' expression references a dimension dependent on {governor} without "
"constraining it directly."
)


def _create_data_id_predicate(
dimension_name: str, value: DataIdValue, universe: DimensionUniverse
) -> qt.Predicate:
"""Create a Predicate that tests whether the given dimension primary key is
equal to the given literal value.
"""
dimension = universe.dimensions[dimension_name]
return qt.Predicate.compare(
qt.DimensionKeyReference(dimension=dimension), "==", qt.make_column_literal(value)
)


class _DataIdExtractionVisitor(
SimplePredicateVisitor,
ColumnExpressionVisitor[tuple[str, None] | tuple[None, Any] | tuple[None, None]],
):
"""A column-expression visitor that extracts quality constraints on
dimensions that are not OR'd with anything else.
Parameters
----------
data_id : `dict`
Dictionary to populate in place.
messages : `list` [ `str` ]
List of diagnostic messages to populate in place.
governor_references : `set` [ `str` ]
Set of the names of governor dimension names that were referenced
directly. This includes dimensions that were constrained to different
values in different logic branches, and hence not included in
``data_id``.
"""

def __init__(self, data_id: dict[str, DataIdValue], messages: list[str], governor_references: set[str]):
self.data_id = data_id
self.messages = messages
self.governor_references = governor_references

def visit_comparison(
self,
a: qt.ColumnExpression,
operator: qt.ComparisonOperator,
b: qt.ColumnExpression,
flags: PredicateVisitFlags,
) -> None:
k_a, v_a = a.visit(self)
k_b, v_b = b.visit(self)
if flags & PredicateVisitFlags.HAS_OR_SIBLINGS:
return None
if flags & PredicateVisitFlags.INVERTED:
if operator == "!=":
operator = "=="
else:
return None
if operator != "==":
return None
if k_a is not None and v_b is not None:
key = k_a
value = v_b
elif k_b is not None and v_a is not None:
key = k_b
value = v_a
else:
return None
if (old := self.data_id.setdefault(key, value)) != value:
self.messages.append(f"'where' expression requires both {key}={value!r} and {key}={old!r}.")
return None

def visit_binary_expression(self, expression: qt.BinaryExpression) -> tuple[None, None]:
expression.a.visit(self)
expression.b.visit(self)
return None, None

def visit_unary_expression(self, expression: qt.UnaryExpression) -> tuple[None, None]:
expression.operand.visit(self)
return None, None

def visit_literal(self, expression: qt.ColumnLiteral) -> tuple[None, Any]:
return None, expression.get_literal_value()

def visit_dimension_key_reference(self, expression: qt.DimensionKeyReference) -> tuple[str, None]:
if expression.dimension.governor is expression.dimension:
self.governor_references.add(expression.dimension.name)
return expression.dimension.name, None

def visit_dimension_field_reference(self, expression: qt.DimensionFieldReference) -> tuple[None, None]:
if (
expression.element.governor is expression.element
and expression.field in expression.element.alternate_keys.names
):
self.governor_references.add(expression.element.name)
return None, None

def visit_dataset_field_reference(self, expression: qt.DatasetFieldReference) -> tuple[None, None]:
return None, None

def visit_reversed(self, expression: qt.Reversed) -> tuple[None, None]:
raise AssertionError("No Reversed expressions in predicates.")
Loading

0 comments on commit 907638f

Please sign in to comment.